Commit 71864fd

update test, implement api, fix sqrt grad
tiancaishaonvjituizi committed Apr 21, 2022
1 parent 7e5f102 commit 71864fd
Showing 5 changed files with 100 additions and 62 deletions.
30 changes: 15 additions & 15 deletions paddle/phi/kernels/sparse/utils.h
@@ -60,52 +60,53 @@
   \
   template <typename T, typename Context> \
   void SparseCoo##dense_kernel_func(const Context& dev_ctx, \
-                                    const SparseCooTensor& x, \
+                                    const SparseCooTensor& x_or_out, \
                                     const SparseCooTensor& out_grad, \
                                     SparseCooTensor* x_grad) { \
     DenseTensor non_zero_indices = \
-        phi::EmptyLike<T, Context>(dev_ctx, x.non_zero_indices()); \
+        phi::EmptyLike<T, Context>(dev_ctx, x_or_out.non_zero_indices()); \
     DenseTensor non_zero_elements = \
-        phi::EmptyLike<T, Context>(dev_ctx, x.non_zero_elements()); \
+        phi::EmptyLike<T, Context>(dev_ctx, x_or_out.non_zero_elements()); \
     phi::Copy(dev_ctx, \
-              x.non_zero_indices(), \
+              x_or_out.non_zero_indices(), \
               dev_ctx.GetPlace(), \
               false, \
               &non_zero_indices); \
     phi::dense_kernel_func<T, Context>(dev_ctx, \
-                                       x.non_zero_elements(), \
+                                       x_or_out.non_zero_elements(), \
                                        out_grad.non_zero_elements(), \
                                        &non_zero_elements); \
-    x_grad->SetMember(non_zero_indices, non_zero_elements, x.dims(), true); \
+    x_grad->SetMember( \
+        non_zero_indices, non_zero_elements, x_or_out.dims(), true); \
   } \
   \
   template <typename T, typename Context> \
   void SparseCsr##dense_kernel_func(const Context& dev_ctx, \
-                                    const SparseCsrTensor& x, \
+                                    const SparseCsrTensor& x_or_out, \
                                     const SparseCsrTensor& out_grad, \
                                     SparseCsrTensor* out) { \
     DenseTensor non_zero_crows = \
-        phi::EmptyLike<T, Context>(dev_ctx, x.non_zero_crows()); \
+        phi::EmptyLike<T, Context>(dev_ctx, x_or_out.non_zero_crows()); \
     DenseTensor non_zero_cols = \
-        phi::EmptyLike<T, Context>(dev_ctx, x.non_zero_cols()); \
+        phi::EmptyLike<T, Context>(dev_ctx, x_or_out.non_zero_cols()); \
     DenseTensor non_zero_elements = \
-        phi::EmptyLike<T, Context>(dev_ctx, x.non_zero_elements()); \
+        phi::EmptyLike<T, Context>(dev_ctx, x_or_out.non_zero_elements()); \
     phi::Copy(dev_ctx, \
-              x.non_zero_crows(), \
+              x_or_out.non_zero_crows(), \
               dev_ctx.GetPlace(), \
               false, \
               &non_zero_crows); \
     phi::Copy(dev_ctx, \
-              x.non_zero_cols(), \
+              x_or_out.non_zero_cols(), \
               dev_ctx.GetPlace(), \
               false, \
               &non_zero_cols); \
     phi::dense_kernel_func<T, Context>(dev_ctx, \
-                                       x.non_zero_elements(), \
+                                       x_or_out.non_zero_elements(), \
                                        out_grad.non_zero_elements(), \
                                        &non_zero_elements); \
     out->SetMember( \
-        non_zero_crows, non_zero_cols, non_zero_elements, x.dims()); \
+        non_zero_crows, non_zero_cols, non_zero_elements, x_or_out.dims()); \
   } \
   } \
   }
@@ -167,4 +168,3 @@
                                           dense_kernel_func) \
   DEFINE_SPARSE_UNARY_GRAD_KERNEL(dense_kernel_func) \
   REGISTER_SPARSE_UNARY_KERNEL(kernel_name, dense_kernel_func)
-
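
The `x` → `x_or_out` rename reflects what these generated wrappers actually receive: depending on the op, the first tensor is the forward input or the forward output (the yaml change below switches the relu and sqrt backward ops to the output). Structurally, each wrapper copies the sparsity layout verbatim and applies the dense elementwise grad kernel to the non-zero values only. A rough numpy sketch of that idea, with illustrative names rather than Paddle API:

import numpy as np

# Sketch of the generated COO grad wrapper: the indices are copied
# unchanged, and the dense grad function runs on the non-zero values only.
def sparse_coo_unary_grad(x_or_out_indices, x_or_out_values,
                          out_grad_values, dense_grad_func):
    non_zero_indices = x_or_out_indices.copy()
    non_zero_elements = dense_grad_func(x_or_out_values, out_grad_values)
    return non_zero_indices, non_zero_elements
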
82 changes: 40 additions & 42 deletions python/paddle/fluid/tests/unittests/test_sparse_activation_op.py
@@ -1,11 +1,11 @@
 # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
-#
+#
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
-#
+#
 #     http://www.apache.org/licenses/LICENSE-2.0
-#
+#
 # Unless required by applicable law or agreed to in writing, software
 # distributed under the License is distributed on an "AS IS" BASIS,
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
@@ -14,57 +14,55 @@

 from __future__ import print_function
 import unittest
+from typing import Union, Callable
 import numpy as np
 import paddle
 from paddle.fluid.framework import _test_eager_guard
-from paddle import _C_ops


 class TestSparseActivation(unittest.TestCase):
-    def test_sparse_relu(self):
+    def compare_with_dense(
+        self,
+        x,
+        to_sparse: Callable[[paddle.Tensor], paddle.Tensor],
+        dense_func: Callable[[paddle.Tensor], paddle.Tensor],
+        sparse_func: Callable[[paddle.Tensor], paddle.Tensor],
+        test_gradient: bool,
+    ):
+        def tensor_equal(dense_tensor: paddle.Tensor,
+                         sparse_tensor: paddle.Tensor):
+            mask = ~np.isnan(dense_tensor.numpy())
+            return np.array_equal(dense_tensor.numpy()[mask],
+                                  sparse_tensor.to_dense().numpy()[mask])
+
         with _test_eager_guard():
-            x = [[0, -1, 0, 2], [0, 0, -3, 0], [4, 5, 0, 0]]
+            dense_x = paddle.to_tensor(
+                x, dtype="float32", stop_gradient=not test_gradient)

-            def dense_relu(x):
-                dense_x = paddle.to_tensor(
-                    x, dtype='float32', stop_gradient=False)
-                dense_relu = paddle.nn.ReLU()
-                dense_out = dense_relu(dense_x)
-                dense_out.backward(dense_out)
-                return dense_out, dense_x.grad
+            sparse_x = to_sparse(dense_x)
+            sparse_out = sparse_func(sparse_x)

-            dense_x = paddle.to_tensor(x, dtype='float32', stop_gradient=False)
-            sparse_dim = 2
-            sparse_x = dense_x.to_sparse_coo(sparse_dim)
-            sparse_relu = paddle.sparse.ReLU()
-            sparse_out = sparse_relu(sparse_x)
-            sparse_out.backward(sparse_out)
+            dense_x = paddle.to_tensor(
+                x, dtype="float32", stop_gradient=not test_gradient)
+            dense_out = dense_func(dense_x)

-            dense_out, dense_x_grad = dense_relu(x)
-            assert np.array_equal(dense_out.numpy(),
-                                  sparse_out.to_dense().numpy())
-            assert np.array_equal(dense_x_grad.numpy(),
-                                  sparse_x.grad.to_dense().numpy())
+            assert tensor_equal(dense_out, sparse_out)

-    def test_sparse_coo_sqrt(self):
-        with _test_eager_guard():
-            x = [[0, 4, 0, 2], [0, 0, 16, 0]]
-            dense_x = paddle.to_tensor(x, dtype='float32')
-            sparse_dim = 2
-            sparse_coo_x = dense_x.to_sparse_coo(sparse_dim)
-            sparse_act_out = _C_ops.final_state_sparse_coo_sqrt(sparse_coo_x)
-            correct_result = [2, np.sqrt(2), 4]
-            actual_result = sparse_act_out.non_zero_elements().numpy()
-            assert np.allclose(correct_result, actual_result)
+            if test_gradient:
+                dense_out.backward(dense_out)
+                sparse_out.backward(sparse_out)
+                assert tensor_equal(dense_x.grad, sparse_x.grad)

-    def test_sparse_csr_sqrt(self):
-        with _test_eager_guard():
-            x = [[0, 4, 0, 2], [0, 0, 0, 0], [0, 0, 16, 0]]
-            dense_x = paddle.to_tensor(x, dtype='float32')
-            sparse_coo_x = dense_x.to_sparse_csr()
-            sparse_act_out = _C_ops.final_state_sparse_csr_sqrt(sparse_coo_x)
-            correct_result = [2, np.sqrt(2), 4]
-            actual_result = sparse_act_out.non_zero_elements().numpy()
-            assert np.allclose(correct_result, actual_result)
+    def test_sparse_relu(self):
+        x = [[0, -1, 0, 2], [0, 0, -3, 0], [4, 5, 0, 0]]
+        sparse_dim = 2
+        self.compare_with_dense(x, lambda x: x.to_sparse_coo(sparse_dim),
+                                lambda x: paddle.nn.ReLU()(x),
+                                lambda x: paddle.sparse.ReLU()(x), True)
+        self.compare_with_dense(x, lambda x: x.to_sparse_csr(),
+                                lambda x: paddle.nn.ReLU()(x),
+                                lambda x: paddle.sparse.ReLU()(x), False)
+
+    def test_sparse_sqrt(self):
+        x = [[0, 16, 0, 0], [0, 0, 0, 0], [0, 4, 2, 0]]
+        sparse_dim = 2
+        self.compare_with_dense(x, lambda x: x.to_sparse_coo(sparse_dim),
+                                paddle.sqrt, paddle.sparse.functional.sqrt,
+                                True)
+        self.compare_with_dense(x, lambda x: x.to_sparse_csr(), paddle.sqrt,
+                                paddle.sparse.functional.sqrt, False)


 if __name__ == "__main__":
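
A note on the NaN mask in tensor_equal: when test_gradient is on, the dense sqrt gradient at a zero entry evaluates to 0 * inf = NaN, whereas the sparse kernel never touches zero entries, so its densified gradient holds plain zeros there. Masking NaN positions compares only the entries both paths actually define. A small numpy illustration with made-up values:

import numpy as np

# Dense grad is NaN at the zero entries; the sparse grad, converted to
# dense, is 0.0 there. The mask drops those positions before comparing.
dense_grad = np.array([[np.nan, 0.5], [0.5, np.nan]])
sparse_grad_as_dense = np.array([[0.0, 0.5], [0.5, 0.0]])
mask = ~np.isnan(dense_grad)
assert np.array_equal(dense_grad[mask], sparse_grad_as_dense[mask])
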
3 changes: 2 additions & 1 deletion python/paddle/sparse/functional/__init__.py
@@ -13,7 +13,8 @@
 # limitations under the License.

 from .activation import relu  # noqa: F401
+from .activation import sqrt  # noqa: F401
 from .conv import conv3d  # noqa: F401
 from .conv import subm_conv3d  # noqa: F401

-__all__ = ['relu', 'conv3d', 'subm_conv3d']
+__all__ = ['relu', 'conv3d', 'subm_conv3d', 'sqrt']
39 changes: 39 additions & 0 deletions python/paddle/sparse/functional/activation.py
@@ -55,3 +55,42 @@ def relu(x, name=None):
     else:
         raise ValueError("Currently, sparse.relu only support the input of SparseCooTensor or SparseCsrTensor")

+
+def sqrt(x, name=None):
+    """
+    sparse sqrt activation.
+
+    .. math::
+
+        out = sqrt(x)
+
+    Parameters:
+        x (Tensor): The input Sparse Tensor with data type float32, float64.
+        name (str, optional): Name for the operation (optional, default is None).
+            For more information, please refer to :ref:`api_guide_Name`.
+
+    Returns:
+        A Sparse Tensor with the same data type and shape as ``x``.
+
+    Examples:
+        .. code-block:: python
+
+            import paddle
+            import numpy as np
+            from paddle.fluid.framework import _test_eager_guard
+
+            with _test_eager_guard():
+                dense_x = paddle.to_tensor(np.array([-2, 0, 1]).astype('float32'))
+                sparse_x = dense_x.to_sparse_coo(1)
+                out = paddle.sparse.functional.sqrt(sparse_x)
+    """
+
+    assert in_dynamic_mode(), "Currently, Sparse API only support dynamic mode"
+
+    if x.is_sparse_coo():
+        return _C_ops.final_state_sparse_coo_sqrt(x)
+    elif x.is_sparse_csr():
+        return _C_ops.final_state_sparse_csr_sqrt(x)
+    else:
+        raise ValueError("Currently, sparse.sqrt only support the input of SparseCooTensor or SparseCsrTensor")
+
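
A quick end-to-end sketch of the new API (assumes a Paddle build containing this commit; eager mode only, per the assert above):

import paddle
from paddle.fluid.framework import _test_eager_guard

with _test_eager_guard():
    dense_x = paddle.to_tensor([[0.0, 4.0], [16.0, 0.0]], stop_gradient=False)
    sparse_x = dense_x.to_sparse_coo(2)
    out = paddle.sparse.functional.sqrt(sparse_x)  # non-zero values: [2., 4.]
    out.backward(out)  # grad of each non-zero: out * 1 / (2 * out) = 0.5
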
8 changes: 4 additions & 4 deletions python/paddle/utils/code_gen/sparse_bw_api.yaml
@@ -34,28 +34,28 @@

 - backward_api : sparse_coo_relu_grad
   forward : sparse_coo_relu(Tensor x) -> Tensor(out@SparseCooTensor)
-  args : (Tensor x, Tensor out_grad)
+  args : (Tensor out, Tensor out_grad)
   output : Tensor(x_grad@SparseCooTensor)
   kernel :
     func : sparse_coo_relu_grad

 - backward_api : sparse_csr_relu_grad
   forward : sparse_csr_relu(Tensor x) -> Tensor(out@SparseCsrTensor)
-  args : (Tensor x, Tensor out_grad)
+  args : (Tensor out, Tensor out_grad)
   output : Tensor(x_grad@SparseCsrTensor)
   kernel :
     func : sparse_csr_relu_grad

 - backward_api : sparse_coo_sqrt_grad
   forward : sparse_coo_sqrt(Tensor x) -> Tensor(out@SparseCooTensor)
-  args : (Tensor x, Tensor out_grad)
+  args : (Tensor out, Tensor out_grad)
   output : Tensor(x_grad@SparseCooTensor)
   kernel :
     func : sparse_coo_sqrt_grad

 - backward_api : sparse_csr_sqrt_grad
   forward : sparse_csr_sqrt(Tensor x) -> Tensor(out@SparseCsrTensor)
-  args : (Tensor x, Tensor out_grad)
+  args : (Tensor out, Tensor out_grad)
   output : Tensor(x_grad@SparseCsrTensor)
   kernel :
     func : sparse_csr_sqrt_grad

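This is the core of the "fix sqrt grad" part of the commit: sqrt's derivative is most naturally expressed through the forward output, since 1 / (2 * sqrt(x)) = 1 / (2 * out), and relu's gradient can equally be read off the output (out > 0 exactly where x > 0). Hence the backward args change from (Tensor x, Tensor out_grad) to (Tensor out, Tensor out_grad). The dense reference semantics, as an illustrative numpy sketch rather than the actual kernels:

import numpy as np

def relu_grad_ref(out, out_grad):
    # relu'(x) is 1 exactly where out = relu(x) > 0
    return out_grad * (out > 0)

def sqrt_grad_ref(out, out_grad):
    # sqrt'(x) = 1 / (2 * sqrt(x)) = 1 / (2 * out)
    return out_grad / (2.0 * out)
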