
[Sparse]add sparse unary api(expm1/deg2rad/rad2deg/relu6/leaky_relu) (#…
zhwesky2010 committed Jul 22, 2022
1 parent 18c7732 commit 19d9c73
Showing 16 changed files with 422 additions and 75 deletions.
27 changes: 27 additions & 0 deletions paddle/phi/api/yaml/sparse_api.yaml
@@ -127,6 +127,24 @@
divide_csr_scalar{sparse_csr -> sparse_csr}
backward : divide_scalar_grad

- api : expm1
args : (Tensor x)
output : Tensor(out)
kernel :
func : expm1_coo{sparse_coo -> sparse_coo},
expm1_csr{sparse_csr -> sparse_csr}
layout : x
backward : expm1_grad

- api : leaky_relu
args : (Tensor x, float alpha)
output : Tensor(out)
kernel :
func : leaky_relu_coo{sparse_coo -> sparse_coo},
leaky_relu_csr{sparse_csr -> sparse_csr}
layout : x
backward : leaky_relu_grad

- api : log1p
args : (Tensor x)
output : Tensor(out)
@@ -163,6 +181,15 @@
layout : x
backward : relu_grad

- api : relu6
args : (Tensor x, float threshold)
output : Tensor(out)
kernel :
func : relu6_coo{sparse_coo -> sparse_coo},
relu6_csr{sparse_csr -> sparse_csr}
layout : x
backward : relu6_grad

- api : scale
args : (Tensor x, float scale, float bias, bool bias_after_scale)
output : Tensor(out)
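Each yaml entry above maps one public API onto a paired COO/CSR kernel. A minimal dygraph sketch of how the new forward entries surface in Python, assuming the paddle.incubate.sparse names exercised by the tests later in this diff:

    import paddle

    dense_x = paddle.to_tensor([-2., 0., 6., 8.])
    sparse_x = dense_x.to_sparse_coo(1)

    # each call dispatches to the matching *_coo kernel declared above
    out1 = paddle.incubate.sparse.expm1(sparse_x)
    out2 = paddle.incubate.sparse.nn.functional.relu6(sparse_x)
    out3 = paddle.incubate.sparse.nn.functional.leaky_relu(sparse_x, 0.1)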
24 changes: 24 additions & 0 deletions paddle/phi/api/yaml/sparse_bw_api.yaml
@@ -122,6 +122,22 @@
output : Tensor(x_grad)
invoke : divide_scalar(out_grad, scalar)

- backward_api : expm1_grad
forward : expm1(Tensor x) -> Tensor(out)
args : (Tensor out, Tensor out_grad)
output : Tensor(x_grad)
kernel :
func : expm1_coo_grad {sparse_coo, sparse_coo -> sparse_coo},
expm1_csr_grad {sparse_csr, sparse_csr -> sparse_csr}

- backward_api : leaky_relu_grad
forward : leaky_relu(Tensor x, float alpha) -> Tensor(out)
args : (Tensor x, Tensor out_grad, float alpha)
output : Tensor(x_grad)
kernel :
func : leaky_relu_coo_grad {sparse_coo, sparse_coo -> sparse_coo},
leaky_relu_csr_grad {sparse_csr, sparse_csr -> sparse_csr}

- backward_api : log1p_grad
forward : log1p(Tensor x) -> Tensor(out)
args : (Tensor x, Tensor out_grad)
@@ -178,6 +194,14 @@
func : pow_coo_grad {sparse_coo, sparse_coo -> sparse_coo},
pow_csr_grad {sparse_csr, sparse_csr -> sparse_csr}

- backward_api : relu6_grad
forward : relu6(Tensor x, float threshold) -> Tensor(out)
args : (Tensor out, Tensor out_grad, float threshold)
output : Tensor(x_grad)
kernel :
func : relu6_coo_grad {sparse_coo, sparse_coo -> sparse_coo},
relu6_csr_grad {sparse_csr, sparse_csr -> sparse_csr}

- backward_api : relu_grad
forward : relu(Tensor x) -> Tensor(out)
args : (Tensor out, Tensor out_grad)
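Note the argument lists: expm1_grad and relu6_grad bind the forward output `out` rather than the input `x`, because their derivatives are recoverable from the output alone (d/dx expm1(x) = exp(x) = out + 1, and relu6's gradient is 1 exactly where 0 < out < 6), while leaky_relu_grad still needs `x` to know each input's sign. A quick dense-side sanity check of the expm1 identity, as a sketch in dygraph mode:

    import paddle

    x = paddle.to_tensor([-1., 0.5, 2.], stop_gradient=False)
    out = paddle.expm1(x)
    out.sum().backward()
    # with an all-ones upstream gradient, x.grad equals exp(x), i.e. out + 1
    print(paddle.allclose(x.grad, out + 1))  # True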
2 changes: 1 addition & 1 deletion paddle/phi/kernels/activation_grad_kernel.h
@@ -240,11 +240,11 @@ DECLARE_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(Swish, beta);
DECLARE_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(Logit, eps);
DECLARE_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(Mish, threshold);
DECLARE_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(Celu, alpha);
-DECLARE_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(Relu6, threshold);
+DECLARE_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPOUT(Relu6, threshold);

DECLARE_ACT_GRAD_KERNEL_WITH_TWO_ATTRS_DEPX(BRelu, t_min, t_max);
DECLARE_ACT_GRAD_KERNEL_WITH_TWO_ATTRS_DEPX(STanh, scale_a, scale_b);
DECLARE_ACT_GRAD_KERNEL_WITH_TWO_ATTRS_DEPX(Softplus, beta, threshold);

DECLARE_ACT_GRAD_KERNEL_WITH_TWO_ATTRS_DEPOUT(HardSigmoid, slope, offset);

} // namespace phi
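The suffix in these DECLARE_ACT_GRAD_KERNEL_* macros records which forward tensor the grad kernel depends on: DEPX variants take the input x, DEPOUT variants take the output out. Relu6 moves to the DEPOUT family here because its derivative is recoverable from the output alone:

    \mathrm{relu6}'(x) = \begin{cases} 1, & 0 < out < 6 \\ 0, & \text{otherwise} \end{cases}

which is consistent with the sparse relu6_grad entry above binding (out, out_grad) instead of (x, out_grad).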
3 changes: 3 additions & 0 deletions paddle/phi/kernels/sparse/cpu/unary_grad_kernel.cc
@@ -51,6 +51,9 @@ PD_REGISTER_SPARSE_UNARY_CPU_GRAD_KERNEL(log1p, Log1p)
PD_REGISTER_SPARSE_UNARY_CPU_GRAD_KERNEL(relu, Relu)
PD_REGISTER_SPARSE_UNARY_CPU_GRAD_KERNEL(abs, Abs)
PD_REGISTER_SPARSE_UNARY_CPU_GRAD_KERNEL(pow, Pow)
PD_REGISTER_SPARSE_UNARY_CPU_GRAD_KERNEL(expm1, Expm1)
PD_REGISTER_SPARSE_UNARY_CPU_GRAD_KERNEL(relu6, Relu6)
PD_REGISTER_SPARSE_UNARY_CPU_GRAD_KERNEL(leaky_relu, LeakyRelu)

PD_REGISTER_KERNEL(cast_coo_grad,
CPU,
3 changes: 3 additions & 0 deletions paddle/phi/kernels/sparse/cpu/unary_kernel.cc
@@ -93,6 +93,9 @@ PD_REGISTER_SPARSE_UNARY_CPU_KERNEL(relu, Relu)
PD_REGISTER_SPARSE_UNARY_CPU_KERNEL(abs, Abs)
PD_REGISTER_SPARSE_UNARY_CPU_KERNEL(pow, Pow)
PD_REGISTER_SPARSE_UNARY_CPU_KERNEL(scale, Scale)
PD_REGISTER_SPARSE_UNARY_CPU_KERNEL(expm1, Expm1)
PD_REGISTER_SPARSE_UNARY_CPU_KERNEL(relu6, Relu6)
PD_REGISTER_SPARSE_UNARY_CPU_KERNEL(leaky_relu, LeakyRelu)

PD_REGISTER_KERNEL(divide_coo_scalar,
CPU,
3 changes: 3 additions & 0 deletions paddle/phi/kernels/sparse/gpu/unary_grad_kernel.cu
@@ -53,6 +53,9 @@ PD_REGISTER_SPARSE_UNARY_GPU_GRAD_KERNEL(log1p, Log1p)
PD_REGISTER_SPARSE_UNARY_GPU_GRAD_KERNEL(relu, Relu)
PD_REGISTER_SPARSE_UNARY_GPU_GRAD_KERNEL(abs, Abs)
PD_REGISTER_SPARSE_UNARY_GPU_GRAD_KERNEL(pow, Pow)
PD_REGISTER_SPARSE_UNARY_GPU_GRAD_KERNEL(expm1, Expm1)
PD_REGISTER_SPARSE_UNARY_GPU_GRAD_KERNEL(relu6, Relu6)
PD_REGISTER_SPARSE_UNARY_GPU_GRAD_KERNEL(leaky_relu, LeakyRelu)

PD_REGISTER_KERNEL(cast_coo_grad,
GPU,
3 changes: 3 additions & 0 deletions paddle/phi/kernels/sparse/gpu/unary_kernel.cu
@@ -98,6 +98,9 @@ PD_REGISTER_SPARSE_UNARY_GPU_KERNEL(relu, Relu)
PD_REGISTER_SPARSE_UNARY_GPU_KERNEL(abs, Abs)
PD_REGISTER_SPARSE_UNARY_GPU_KERNEL(pow, Pow)
PD_REGISTER_SPARSE_UNARY_GPU_KERNEL(scale, Scale)
PD_REGISTER_SPARSE_UNARY_GPU_KERNEL(expm1, Expm1)
PD_REGISTER_SPARSE_UNARY_GPU_KERNEL(relu6, Relu6)
PD_REGISTER_SPARSE_UNARY_GPU_KERNEL(leaky_relu, LeakyRelu)

PD_REGISTER_KERNEL(divide_coo_scalar,
GPU,
3 changes: 3 additions & 0 deletions paddle/phi/kernels/sparse/impl/unary_grad_kernel_impl.h
@@ -93,7 +93,10 @@ DEFINE_SPARSE_UNARY_GRAD_KERNEL(Square)
DEFINE_SPARSE_UNARY_GRAD_KERNEL(Log1p)
DEFINE_SPARSE_UNARY_GRAD_KERNEL(Relu)
DEFINE_SPARSE_UNARY_GRAD_KERNEL(Abs)
DEFINE_SPARSE_UNARY_GRAD_KERNEL(Expm1)
DEFINE_SPARSE_UNARY_GRAD_KERNEL_WITH_ONE_ATTR(Pow, factor)
DEFINE_SPARSE_UNARY_GRAD_KERNEL_WITH_ONE_ATTR(LeakyRelu, alpha)
DEFINE_SPARSE_UNARY_GRAD_KERNEL_WITH_ONE_ATTR(Relu6, threshold)

template <typename T, typename Context>
void CastCooGradKernel(const Context& dev_ctx,
3 changes: 3 additions & 0 deletions paddle/phi/kernels/sparse/impl/unary_kernel_impl.h
@@ -86,7 +86,10 @@ DEFINE_SPARSE_UNARY_KERNEL(Square)
DEFINE_SPARSE_UNARY_KERNEL(Log1p)
DEFINE_SPARSE_UNARY_KERNEL(Relu)
DEFINE_SPARSE_UNARY_KERNEL(Abs)
DEFINE_SPARSE_UNARY_KERNEL(Expm1)
DEFINE_SPARSE_UNARY_KERNEL_WITH_ONE_ATTR(Pow, factor)
DEFINE_SPARSE_UNARY_KERNEL_WITH_ONE_ATTR(Relu6, threshold)
DEFINE_SPARSE_UNARY_KERNEL_WITH_ONE_ATTR(LeakyRelu, alpha)

template <typename T, typename Context>
void ScaleCooKernel(const Context& dev_ctx,
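These DEFINE_SPARSE_UNARY_KERNEL macros generate the _coo and _csr kernels by applying the elementwise op to the stored non-zero values and reusing the sparsity pattern unchanged. That shortcut is only sound because every op added here preserves zero: expm1(0) = relu6(0) = leaky_relu(0) = deg2rad(0) = rad2deg(0) = 0. A minimal NumPy sketch of the invariant (a hypothetical helper, not the Paddle kernel):

    import numpy as np

    def sparse_unary(indices, values, f):
        # touch only the stored values; the indices are reused unchanged
        return indices, f(values)

    dense = np.array([0.0, -2.0, 0.0, 1.5])
    idx = np.nonzero(dense)[0]
    idx, vals = sparse_unary(idx, dense[idx], np.expm1)

    result = np.zeros_like(dense)
    result[idx] = vals
    assert np.allclose(result, np.expm1(dense))  # holds because expm1(0) == 0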
17 changes: 17 additions & 0 deletions python/paddle/fluid/tests/unittests/test_sparse_unary_op.py
@@ -123,9 +123,26 @@ def test_sparse_relu(self):
self.compare_with_dense(paddle.nn.ReLU(),
paddle.incubate.sparse.nn.ReLU())

def test_sparse_relu6(self):
self.compare_with_dense(paddle.nn.ReLU6(),
paddle.incubate.sparse.nn.ReLU6())

def test_sparse_leaky_relu(self):
self.compare_with_dense(paddle.nn.LeakyReLU(0.1),
paddle.incubate.sparse.nn.LeakyReLU(0.1))

def test_sparse_abs(self):
self.compare_with_dense(paddle.abs, paddle.incubate.sparse.abs)

def test_sparse_expm1(self):
self.compare_with_dense(paddle.expm1, paddle.incubate.sparse.expm1)

def test_sparse_deg2rad(self):
self.compare_with_dense(paddle.deg2rad, paddle.incubate.sparse.deg2rad)

def test_sparse_rad2deg(self):
self.compare_with_dense(paddle.rad2deg, paddle.incubate.sparse.rad2deg)

def test_sparse_neg(self):
self.compare_with_dense(paddle.neg, paddle.incubate.sparse.neg)

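The compare_with_dense helper is defined outside the lines shown here. A plausible sketch of its shape, assuming it round-trips the input through the sparse format and compares forward results (the helper name comes from this test file, the body is an assumption):

    import numpy as np
    import paddle

    def compare_with_dense(dense_func, sparse_func):
        dense_x = paddle.to_tensor([-2., 0., 0., 3.])
        sparse_x = dense_x.to_sparse_coo(1)
        dense_out = dense_func(dense_x)
        sparse_out = sparse_func(sparse_x)
        np.testing.assert_allclose(dense_out.numpy(),
                                   sparse_out.to_dense().numpy(),
                                   rtol=1e-05)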
6 changes: 6 additions & 0 deletions python/paddle/incubate/sparse/__init__.py
@@ -31,6 +31,9 @@
from .unary import cast
from .unary import neg
from .unary import coalesce
from .unary import deg2rad
from .unary import rad2deg
from .unary import expm1

from .binary import mv
from .binary import matmul
@@ -62,6 +65,9 @@
'pow',
'cast',
'neg',
'deg2rad',
'rad2deg',
'expm1',
'mv',
'matmul',
'masked_matmul',
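deg2rad and rad2deg appear only on the Python side of this diff. Both are pure rescalings (deg2rad(x) = x * pi / 180, rad2deg(x) = x * 180 / pi), so they can plausibly be built on the existing sparse scale kernel rather than on new yaml entries; that is an inference from this view, not something the diff shows. Expected usage, as a sketch based on the tests above:

    import paddle

    dense_x = paddle.to_tensor([-180., 0., 90.])
    sparse_x = dense_x.to_sparse_coo(1)
    out = paddle.incubate.sparse.deg2rad(sparse_x)
    # non-zero values scaled by pi / 180, roughly [-3.14159, 1.57080]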
4 changes: 4 additions & 0 deletions python/paddle/incubate/sparse/nn/__init__.py
@@ -16,13 +16,17 @@

from .layer.activation import ReLU
from .layer.activation import Softmax
from .layer.activation import ReLU6
from .layer.activation import LeakyReLU
from .layer.norm import BatchNorm
from .layer.conv import Conv3D
from .layer.conv import SubmConv3D
from .layer.pooling import MaxPool3D

__all__ = [
'ReLU',
'ReLU6',
'LeakyReLU',
'Softmax',
'BatchNorm',
'Conv3D',
4 changes: 4 additions & 0 deletions python/paddle/incubate/sparse/nn/functional/__init__.py
@@ -17,13 +17,17 @@
from .transformer import attention # noqa: F401
from .pooling import max_pool3d # noqa: F401
from .activation import relu # noqa: F401
from .activation import relu6 # noqa: F401
from .activation import leaky_relu # noqa: F401
from .activation import softmax # noqa: F401

__all__ = [
'conv3d',
'subm_conv3d',
'max_pool3d',
'relu',
'relu6',
'leaky_relu',
'softmax',
'attention',
]
119 changes: 90 additions & 29 deletions python/paddle/incubate/sparse/nn/functional/activation.py
@@ -21,7 +21,7 @@
@dygraph_only
def relu(x, name=None):
"""
-    sparse relu activation, requiring x to be a sparse coo or sparse csr tensor.
+    sparse relu activation, requiring x to be a SparseCooTensor or SparseCsrTensor.
.. math::
@@ -39,20 +39,19 @@ def relu(x, name=None):
.. code-block:: python
import paddle
-            from paddle.fluid.framework import _test_eager_guard
-            with _test_eager_guard():
-                dense_x = paddle.to_tensor([-2, 0, 1], dtype='float32')
-                sparse_x = dense_x.to_sparse_coo(1)
-                out = paddle.incubate.sparse.nn.functional.relu(sparse_x)
+            dense_x = paddle.to_tensor([-2., 0., 1.])
+            sparse_x = dense_x.to_sparse_coo(1)
+            out = paddle.incubate.sparse.nn.functional.relu(sparse_x)
+            # [0., 0., 1.]
"""
return _C_ops.final_state_sparse_relu(x)


@dygraph_only
def softmax(x, axis=-1, name=None):
"""
-    sparse softmax activation, x must be SparseCsrTensor or SparseCooTensor.
+    sparse softmax activation, requiring x to be a SparseCooTensor or SparseCsrTensor.
Note:
Only support axis=-1 for SparseCsrTensor, which is faster when reading data
@@ -79,30 +78,92 @@
import paddle
import numpy as np
-            from paddle.fluid.framework import _test_eager_guard
            paddle.seed(100)
-            with _test_eager_guard():
-                mask = np.random.rand(3, 4) < 0.5
-                np_x = np.random.rand(3, 4) * mask
-                # [[0. 0. 0.96823406 0.19722934]
-                # [0.94373937 0. 0.02060066 0.71456372]
-                # [0. 0. 0. 0.98275049]]
-                csr = paddle.to_tensor(np_x).to_sparse_csr()
-                # Tensor(shape=[3, 4], dtype=paddle.float64, place=Place(gpu:0), stop_gradient=True,
-                #        crows=[0, 2, 5, 6],
-                #        cols=[2, 3, 0, 2, 3, 3],
-                #        values=[0.96823406, 0.19722934, 0.94373937, 0.02060066, 0.71456372,
-                #                0.98275049])
-                out = paddle.incubate.sparse.nn.functional.softmax(csr)
-                # Tensor(shape=[3, 4], dtype=paddle.float64, place=Place(gpu:0), stop_gradient=True,
-                #        crows=[0, 2, 5, 6],
-                #        cols=[2, 3, 0, 2, 3, 3],
-                #        values=[0.68373820, 0.31626180, 0.45610887, 0.18119845, 0.36269269,
-                #                1.        ])
+            mask = np.random.rand(3, 4) < 0.5
+            np_x = np.random.rand(3, 4) * mask
+            # [[0. 0. 0.96823406 0.19722934]
+            # [0.94373937 0. 0.02060066 0.71456372]
+            # [0. 0. 0. 0.98275049]]
+            csr = paddle.to_tensor(np_x).to_sparse_csr()
+            # Tensor(shape=[3, 4], dtype=paddle.float64, place=Place(gpu:0), stop_gradient=True,
+            #        crows=[0, 2, 5, 6],
+            #        cols=[2, 3, 0, 2, 3, 3],
+            #        values=[0.96823406, 0.19722934, 0.94373937, 0.02060066, 0.71456372,
+            #                0.98275049])
+            out = paddle.incubate.sparse.nn.functional.softmax(csr)
+            # Tensor(shape=[3, 4], dtype=paddle.float64, place=Place(gpu:0), stop_gradient=True,
+            #        crows=[0, 2, 5, 6],
+            #        cols=[2, 3, 0, 2, 3, 3],
+            #        values=[0.68373820, 0.31626180, 0.45610887, 0.18119845, 0.36269269,
+            #                1.        ])
"""
return _C_ops.final_state_sparse_softmax(x, axis)


@dygraph_only
def relu6(x, name=None):
"""
sparse relu6 activation, requiring x to be a SparseCooTensor or SparseCsrTensor.
.. math::
relu6(x) = min(max(0, x), 6)
Parameters:
x (Tensor): The input Sparse Tensor with data type float32, float64.
name (str, optional): Name for the operation (optional, default is None).
For more information, please refer to :ref:`api_guide_Name`.
Returns:
A Sparse Tensor with the same data type and shape as ``x`` .
Examples:
.. code-block:: python
import paddle
dense_x = paddle.to_tensor([-2., 0., 8.])
sparse_x = dense_x.to_sparse_coo(1)
out = paddle.incubate.sparse.nn.functional.relu6(sparse_x)
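            # [0., 0., 6.]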
"""
return _C_ops.final_state_sparse_relu6(x, 6.0)


@dygraph_only
def leaky_relu(x, negative_slope=0.01, name=None):
"""
sparse leaky_relu activation, requiring x to be a SparseCooTensor or SparseCsrTensor.
.. math::
leaky\_relu(x)=
\left\{
\begin{array}{rcl}
x, & & if \ x >= 0 \\
negative\_slope * x, & & otherwise \\
\end{array}
\right.
Parameters:
x (Tensor): The input Sparse Tensor with data type float32, float64.
negative_slope (float, optional): Slope of the activation function at
:math:`x < 0` . Default is 0.01.
name (str, optional): Name for the operation (optional, default is None).
For more information, please refer to :ref:`api_guide_Name`.
Returns:
A Sparse Tensor with the same data type and shape as ``x`` .
Examples:
.. code-block:: python
import paddle
dense_x = paddle.to_tensor([-2., 0., 5.])
sparse_x = dense_x.to_sparse_coo(1)
out = paddle.incubate.sparse.nn.functional.leaky_relu(sparse_x, 0.5)
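            # [-1., 0., 5.]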
"""
return _C_ops.final_state_sparse_leaky_relu(x, negative_slope)
