Skip to content

Commit

Permalink
Add paddle.sparse and three Sparse API (PaddlePaddle#41276)
Browse files Browse the repository at this point in the history
  • Loading branch information
zhangkaihuo committed Apr 14, 2022
1 parent 921a6fb commit d62690c
Show file tree
Hide file tree
Showing 10 changed files with 461 additions and 23 deletions.
1 change: 1 addition & 0 deletions python/paddle/__init__.py
Expand Up @@ -75,6 +75,7 @@
import paddle.reader # noqa: F401
import paddle.static # noqa: F401
import paddle.vision # noqa: F401
import paddle.sparse # noqa: F401

from .tensor.attribute import is_complex # noqa: F401
from .tensor.attribute import is_integer # noqa: F401
Expand Down
21 changes: 11 additions & 10 deletions python/paddle/fluid/tests/unittests/test_sparse_activation_op.py
Expand Up @@ -16,24 +16,25 @@
import unittest
import numpy as np
import paddle
from paddle import _C_ops
from paddle.fluid.framework import _test_eager_guard


class TestSparseActivation(unittest.TestCase):
def test_sparse_relu(self):
with _test_eager_guard():
x = [[0, -1, 0, 2], [0, 0, -3, 0], [4, 5, 0, 0]]
dense_x = paddle.to_tensor(x, dtype='float32')
dense_shape = [3, 4]
stop_gradient = True
dense_x = paddle.to_tensor(x, dtype='float32', stop_gradient=False)
sparse_dim = 2
sparse_coo_x = dense_x.to_sparse_coo(sparse_dim)
#TODO(zhangkaihuo): change to test the corresponding API: paddle.sparse.relu(sparse_coo_x)
sparse_act_out = _C_ops.final_state_sparse_relu(sparse_coo_x)
correct_result = [0, 2, 0, 4, 5]
actual_result = sparse_act_out.non_zero_elements().numpy()
assert np.array_equal(correct_result, actual_result)
sparse_x = dense_x.to_sparse_coo(sparse_dim)
sparse_relu = paddle.sparse.ReLU()
sparse_out = sparse_relu(sparse_x)
dense_relu = paddle.nn.ReLU()
#TODO: replace non_zero_elements() as values()
dense_out = dense_relu(sparse_x.non_zero_elements())
actual_result = sparse_out.non_zero_elements().numpy()
assert np.array_equal(dense_out.numpy(), actual_result)
dense_out.backward(dense_out)
sparse_out.backward(sparse_out)


if __name__ == "__main__":
Expand Down
99 changes: 87 additions & 12 deletions python/paddle/fluid/tests/unittests/test_sparse_utils_op.py
Expand Up @@ -16,27 +16,37 @@
import unittest
import numpy as np
import paddle
from paddle import _C_ops
from paddle.fluid import core
import paddle.fluid.core as core
from paddle.fluid.framework import _test_eager_guard


class TestSparseUtils(unittest.TestCase):
def test_create_sparse_coo_tensor(self):
class TestSparseCreate(unittest.TestCase):
def test_create_coo_by_tensor(self):
with _test_eager_guard():
non_zero_indices = [[0, 0, 1, 2, 2], [1, 3, 2, 0, 1]]
non_zero_elements = [1, 2, 3, 4, 5]
dense_shape = [3, 4]
dense_indices = paddle.to_tensor(non_zero_indices)
dense_elements = paddle.to_tensor(
non_zero_elements, dtype='float32')
stop_gradient = False
coo = core.eager.sparse_coo_tensor(dense_indices, dense_elements,
dense_shape, stop_gradient)
coo = paddle.sparse.sparse_coo_tensor(
dense_indices, dense_elements, dense_shape, stop_gradient=False)
assert np.array_equal(non_zero_indices,
coo.non_zero_indices().numpy())
assert np.array_equal(non_zero_elements,
coo.non_zero_elements().numpy())

def test_create_coo_by_np(self):
    """Build a COO tensor from plain Python lists and check what it stores.

    The indices and values handed to ``paddle.sparse.sparse_coo_tensor``
    must come back unchanged from ``non_zero_indices()`` and
    ``non_zero_elements()``.
    """
    with _test_eager_guard():
        indices = [[0, 1, 2], [1, 2, 0]]
        values = [1.0, 2.0, 3.0]
        # The first row of ``indices`` goes up to 2, so the dense tensor
        # needs three rows; a [2, 3] shape could not hold that element.
        dense_shape = [3, 3]
        coo = paddle.sparse.sparse_coo_tensor(indices, values, dense_shape)
        # Printing is kept from the original test; presumably it exercises
        # the string conversion of a sparse tensor -- TODO confirm.
        print(coo)
        assert np.array_equal(indices, coo.non_zero_indices().numpy())
        assert np.array_equal(values, coo.non_zero_elements().numpy())

def test_create_sparse_csr_tensor(self):
def test_create_csr_by_tensor(self):
with _test_eager_guard():
non_zero_crows = [0, 2, 3, 5]
non_zero_cols = [1, 3, 2, 0, 1]
Expand All @@ -47,12 +57,77 @@ def test_create_sparse_csr_tensor(self):
dense_elements = paddle.to_tensor(
non_zero_elements, dtype='float32')
stop_gradient = False
csr = core.eager.sparse_csr_tensor(dense_crows, dense_cols,
dense_elements, dense_shape,
stop_gradient)

csr = paddle.sparse.sparse_csr_tensor(
dense_crows,
dense_cols,
dense_elements,
dense_shape,
stop_gradient=stop_gradient)
print(csr)

def test_create_csr_by_np(self):
    """Build a CSR tensor from plain Python lists and check what it stores.

    ``crows``, ``cols`` and ``values`` must be returned unchanged by
    ``non_zero_crows()``, ``non_zero_cols()`` and ``non_zero_elements()``.
    """
    with _test_eager_guard():
        expected_crows = [0, 2, 3, 5]
        expected_cols = [1, 3, 2, 0, 1]
        expected_values = [1, 2, 3, 4, 5]
        csr = paddle.sparse.sparse_csr_tensor(
            expected_crows, expected_cols, expected_values, [3, 4])

        stored_crows = csr.non_zero_crows().numpy()
        assert np.array_equal(expected_crows, stored_crows)
        stored_cols = csr.non_zero_cols().numpy()
        assert np.array_equal(expected_cols, stored_cols)
        stored_values = csr.non_zero_elements().numpy()
        assert np.array_equal(expected_values, stored_values)

def test_place(self):
    """Check that ``place`` is honoured by both sparse constructors.

    A COO and a CSR tensor are created on ``CPUPlace``; the sparse tensor
    itself and each of its component tensors must report a CPU place.
    """
    with _test_eager_guard():
        cpu = core.CPUPlace()

        # COO: the tensor, its values and its indices all live on the CPU.
        coo = paddle.sparse.sparse_coo_tensor(
            [[0, 1], [0, 1]], [1.0, 2.0], [2, 2], place=cpu)
        assert coo.place.is_cpu_place()
        assert coo.non_zero_elements().place.is_cpu_place()
        assert coo.non_zero_indices().place.is_cpu_place()

        # CSR: same check for crows, cols and values.
        csr_crows = [0, 2, 3, 5]
        csr_cols = [1, 3, 2, 0, 1]
        csr_values = [1.0, 2.0, 3.0, 4.0, 5.0]
        csr = paddle.sparse.sparse_csr_tensor(
            csr_crows, csr_cols, csr_values, [3, 5], place=cpu)
        assert csr.place.is_cpu_place()
        assert csr.non_zero_crows().place.is_cpu_place()
        assert csr.non_zero_cols().place.is_cpu_place()
        assert csr.non_zero_elements().place.is_cpu_place()

def test_dtype(self):
    """Check that an explicit ``dtype`` determines the sparse tensor's dtype.

    The COO case starts from float32 Tensor values and requests float64;
    the CSR case starts from Python floats and requests float16.
    """
    with _test_eager_guard():
        # COO built from Tensors whose dtype differs from the requested one.
        coo_indices = paddle.to_tensor([[0, 1], [0, 1]], dtype='int32')
        coo_values = paddle.to_tensor([1.0, 2.0], dtype='float32')
        coo = paddle.sparse.sparse_coo_tensor(
            coo_indices, coo_values, [2, 2], dtype='float64')
        assert coo.dtype == paddle.float64

        # CSR built directly from Python lists.
        csr_crows = [0, 2, 3, 5]
        csr_cols = [1, 3, 2, 0, 1]
        csr_values = [1.0, 2.0, 3.0, 4.0, 5.0]
        csr = paddle.sparse.sparse_csr_tensor(
            csr_crows, csr_cols, csr_values, [3, 5], dtype='float16')
        assert csr.dtype == paddle.float16

def test_create_coo_no_shape(self):
    """Check that the dense shape is inferred when ``shape`` is omitted.

    The largest index in each row of ``indices`` is 1, so the resulting
    tensor is expected to have shape ``[2, 2]``.
    """
    with _test_eager_guard():
        index_tensor = paddle.to_tensor([[0, 1], [0, 1]], dtype='int32')
        value_tensor = paddle.to_tensor([1.0, 2.0], dtype='float32')
        coo = paddle.sparse.sparse_coo_tensor(index_tensor, value_tensor)
        assert [2, 2] == coo.shape


class TestSparseConvert(unittest.TestCase):
def test_to_sparse_coo(self):
with _test_eager_guard():
x = [[0, 1, 0, 2], [0, 0, 3, 0], [4, 5, 0, 0]]
Expand Down
19 changes: 19 additions & 0 deletions python/paddle/sparse/__init__.py
@@ -0,0 +1,19 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .creation import sparse_coo_tensor
from .creation import sparse_csr_tensor
from .layer.activation import ReLU

__all__ = ['sparse_coo_tensor', 'sparse_csr_tensor', 'ReLU']
191 changes: 191 additions & 0 deletions python/paddle/sparse/creation.py
@@ -0,0 +1,191 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from paddle import _C_ops
from ..framework import core, dygraph_only
from ..tensor import to_tensor
from ..tensor import max
from ..fluid.data_feeder import check_variable_and_dtype, check_type, check_dtype, convert_dtype

__all__ = [
'sparse_coo_tensor',
'sparse_csr_tensor',
]


def _handle_dtype(data, dtype):
if dtype:
if convert_dtype(dtype) != convert_dtype(data.dtype):
return data.astype(convert_dtype(dtype))
return data


def _infer_dense_shape(indices):
    """Infer the smallest dense shape that can hold every index in ``indices``.

    Each row of the 2-D ``indices`` tensor contributes one entry to the
    result: the largest value found in that row plus one.

    Note that ``max`` here is the function imported from ``..tensor`` at the
    top of this module, not the Python builtin.

    Args:
        indices: 2-D tensor of non-zero element indices.

    Returns:
        list: One entry per row of ``indices``, taken from the ``numpy()``
        view of the computed tensor.
    """
    rank = len(indices.shape)
    assert rank == 2
    largest_index = max(indices, axis=1)
    return list((largest_index + 1).numpy())


@dygraph_only
def sparse_coo_tensor(indices,
                      values,
                      shape=None,
                      dtype=None,
                      place=None,
                      stop_gradient=True):
    r"""
    Constructs a sparse ``paddle.Tensor`` in coordinate (COO) format from the
    indices and values of its non-zero elements.

    Args:
        indices(list|tuple|ndarray|Tensor): The indices of the non-zero elements.
            Must be 2-D. Anything that is not already a Tensor is converted
            with ``to_tensor``.
        values(list|tuple|ndarray|Tensor): The values of the non-zero elements.
            Anything that is not already a Tensor is converted with ``to_tensor``.
        shape(list|tuple, optional): The shape of the sparse tensor, which is also
            the shape of the equivalent dense tensor. Default: None, in which case
            the smallest shape able to hold all elements is inferred: the largest
            value in each row of ``indices`` plus one.
        dtype(str|np.dtype, optional): The desired data type of ``values``.
            Default: None, which keeps the data type ``values`` already has
            (or the one ``to_tensor`` infers for non-Tensor input).
        place(CPUPlace|CUDAPinnedPlace|CUDAPlace|str, optional): The place to allocate
            the tensor. When given, ``indices`` and ``values`` are copied to it.
            Default: None.
        stop_gradient(bool, optional): Whether to block the gradient propagation
            of Autograd. Default: True.

    Returns:
        Tensor: A sparse COO Tensor constructed from ``indices`` and ``values`` .

    Raises:
        ValueError: If ``indices`` is not 2-D.

        Errors raised by ``to_tensor`` while converting ``indices`` or ``values``
        are propagated unchanged.

    Examples:

    .. code-block:: python

        import paddle
        from paddle.fluid.framework import _test_eager_guard

        with _test_eager_guard():
            indices = [[0, 1, 2], [1, 2, 0]]
            values = [1.0, 2.0, 3.0]
            # Row indices go up to 2, so the dense shape needs 3 rows.
            dense_shape = [3, 3]
            coo = paddle.sparse.sparse_coo_tensor(indices, values, dense_shape)
            # print(coo)
            # (the reported place depends on the device in use)
            # Tensor(shape=[3, 3], dtype=paddle.float32, place=Place(gpu:0), stop_gradient=True,
            #       indices=[[0, 1, 2],
            #                [1, 2, 0]],
            #       values=[1., 2., 3.])
    """

    # Indices never take part in autograd, so they are always created with
    # stop_gradient=True and without a dtype override.
    if not isinstance(indices, core.eager.Tensor):
        indices = to_tensor(
            indices, dtype=None, place=place, stop_gradient=True)
    if not isinstance(values, core.eager.Tensor):
        values = to_tensor(
            values, dtype=dtype, place=place, stop_gradient=stop_gradient)
    if len(indices.shape) != 2:
        raise ValueError("'indices' must be 2-D.")

    if place is not None:
        # Inputs that were already Tensors may live elsewhere, so both
        # components are copied to the requested place.
        # NOTE(review): ``place`` is forwarded as-is; confirm that
        # ``_copy_to`` accepts the string form documented above.
        indices = indices._copy_to(place, False)
        values = values._copy_to(place, False)

    # Tensor inputs bypass ``to_tensor``, so the dtype request is applied here.
    values = _handle_dtype(values, dtype)
    if shape is None:
        shape = _infer_dense_shape(indices)
    return core.eager.sparse_coo_tensor(indices, values, shape, stop_gradient)


# TODO: support shape=None by inferring the dense shape.
@dygraph_only
def sparse_csr_tensor(crows,
                      cols,
                      values,
                      shape,
                      dtype=None,
                      place=None,
                      stop_gradient=True):
    r"""
    Constructs a sparse ``paddle.Tensor`` in CSR (Compressed Sparse Row) format
    from ``crows``, ``cols`` and ``values``.

    Args:
        crows(list|tuple|ndarray|Tensor): 1-D array; each element gives the starting
            position in ``values`` of the first non-zero element of a row.
            Anything that is not already a Tensor is converted with ``to_tensor``.
        cols(list|tuple|ndarray|Tensor): 1-D array, the column of each non-zero element.
            Anything that is not already a Tensor is converted with ``to_tensor``.
        values(list|tuple|ndarray|Tensor): 1-D array, the non-zero elements.
            Anything that is not already a Tensor is converted with ``to_tensor``.
        shape(list|tuple): The shape of the sparse tensor, which is also the shape of
            the equivalent dense tensor. Required; it is not inferred.
        dtype(str|np.dtype, optional): The desired data type of ``values``.
            Default: None, which keeps the data type ``values`` already has
            (or the one ``to_tensor`` infers for non-Tensor input).
        place(CPUPlace|CUDAPinnedPlace|CUDAPlace|str, optional): The place to allocate
            the tensor. When given, ``crows``, ``cols`` and ``values`` are copied
            to it. Default: None.
        stop_gradient(bool, optional): Whether to block the gradient propagation
            of Autograd. Default: True.

    Returns:
        Tensor: A sparse CSR Tensor constructed from ``crows``, ``cols`` and ``values`` .

    Raises:
        ValueError: If ``crows``, ``cols`` and ``values`` are not all 1-D.

        Errors raised by ``to_tensor`` while converting the inputs are
        propagated unchanged.

    Examples:

    .. code-block:: python

        import paddle
        from paddle.fluid.framework import _test_eager_guard

        with _test_eager_guard():
            crows = [0, 2, 3, 5]
            cols = [1, 3, 2, 0, 1]
            values = [1, 2, 3, 4, 5]
            dense_shape = [3, 4]
            csr = paddle.sparse.sparse_csr_tensor(crows, cols, values, dense_shape)
            # print(csr)
            # (the reported place depends on the device in use)
            # Tensor(shape=[3, 4], dtype=paddle.int64, place=Place(gpu:0), stop_gradient=True,
            #       crows=[0, 2, 3, 5],
            #       cols=[1, 3, 2, 0, 1],
            #       values=[1, 2, 3, 4, 5])
    """
    # crows/cols never take part in autograd, so they are always created
    # with stop_gradient=True and without a dtype override.
    if not isinstance(crows, core.eager.Tensor):
        crows = to_tensor(crows, dtype=None, place=place, stop_gradient=True)
    if not isinstance(cols, core.eager.Tensor):
        cols = to_tensor(cols, dtype=None, place=place, stop_gradient=True)
    if not isinstance(values, core.eager.Tensor):
        values = to_tensor(
            values, dtype=dtype, place=place, stop_gradient=stop_gradient)
    if len(crows.shape) != 1 or len(cols.shape) != 1 or len(values.shape) != 1:
        raise ValueError(
            "SparseCsrTensor only support 2-D or 3-D matrix. The 'crows', 'cols' and 'values' must be 1-D."
        )

    if place is not None:
        # Inputs that were already Tensors may live elsewhere, so every
        # component is copied to the requested place.
        # NOTE(review): ``place`` is forwarded as-is; confirm that
        # ``_copy_to`` accepts the string form documented above.
        crows = crows._copy_to(place, False)
        cols = cols._copy_to(place, False)
        values = values._copy_to(place, False)

    # Tensor inputs bypass ``to_tensor``, so the dtype request is applied here.
    values = _handle_dtype(values, dtype)
    return core.eager.sparse_csr_tensor(crows, cols, values, shape,
                                        stop_gradient)

0 comments on commit d62690c

Please sign in to comment.