[Sparse] Add sparse unary APIs (expm1/deg2rad/rad2deg/relu6/leaky_relu) #44432

Merged: 1 commit, Jul 22, 2022
27 changes: 27 additions & 0 deletions paddle/phi/api/yaml/sparse_api.yaml
@@ -127,6 +127,24 @@
divide_csr_scalar{sparse_csr -> sparse_csr}
backward : divide_scalar_grad

- api : expm1
args : (Tensor x)
output : Tensor(out)
kernel :
func : expm1_coo{sparse_coo -> sparse_coo},
expm1_csr{sparse_csr -> sparse_csr}
layout : x
backward : expm1_grad

- api : leaky_relu
args : (Tensor x, float alpha)
output : Tensor(out)
kernel :
func : leaky_relu_coo{sparse_coo -> sparse_coo},
leaky_relu_csr{sparse_csr -> sparse_csr}
layout : x
backward : leaky_relu_grad

- api : log1p
args : (Tensor x)
output : Tensor(out)
@@ -163,6 +181,15 @@
layout : x
backward : relu_grad

- api : relu6
args : (Tensor x, float threshold)
output : Tensor(out)
kernel :
func : relu6_coo{sparse_coo -> sparse_coo},
relu6_csr{sparse_csr -> sparse_csr}
layout : x
backward : relu6_grad

- api : scale
args : (Tensor x, float scale, float bias, bool bias_after_scale)
output : Tensor(out)
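A minimal usage sketch of the forward entries declared above (an assumption of typical dygraph usage, mirroring the docstring examples added later in this PR; dense views shown in the comments):

import paddle

dense_x = paddle.to_tensor([-2., 0., 8.])
sparse_x = dense_x.to_sparse_coo(1)

# expm1 maps 0 -> 0, so applying it to the stored values only
# preserves the sparsity pattern.
out = paddle.incubate.sparse.expm1(sparse_x)
# dense view: [expm1(-2.), 0., expm1(8.)] ~= [-0.8647, 0., 2979.9580]

out = paddle.incubate.sparse.nn.functional.relu6(sparse_x)
# dense view: [0., 0., 6.]

out = paddle.incubate.sparse.nn.functional.leaky_relu(sparse_x, 0.1)
# dense view: [-0.2, 0., 8.]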
26 changes: 25 additions & 1 deletion paddle/phi/api/yaml/sparse_bw_api.yaml
@@ -112,6 +112,22 @@
output : Tensor(x_grad)
invoke : divide_scalar(out_grad, scalar)

- backward_api : expm1_grad
forward : expm1(Tensor x) -> Tensor(out)
args : (Tensor out, Tensor out_grad)
output : Tensor(x_grad)
kernel :
func : expm1_coo_grad {sparse_coo, sparse_coo -> sparse_coo},
expm1_csr_grad {sparse_csr, sparse_csr -> sparse_csr}

- backward_api : leaky_relu_grad
forward : leaky_relu(Tensor x, float alpha) -> Tensor(out)
args : (Tensor x, Tensor out_grad, float alpha)
output : Tensor(x_grad)
kernel :
func : leaky_relu_coo_grad {sparse_coo, sparse_coo -> sparse_coo},
leaky_relu_csr_grad {sparse_csr, sparse_csr -> sparse_csr}

- backward_api : log1p_grad
forward : log1p(Tensor x) -> Tensor(out)
args : (Tensor x, Tensor out_grad)
@@ -161,6 +177,14 @@
func : pow_coo_grad {sparse_coo, sparse_coo -> sparse_coo},
pow_csr_grad {sparse_csr, sparse_csr -> sparse_csr}

- backward_api : relu6_grad
forward : relu6(Tensor x, float threshold) -> Tensor(out)
args : (Tensor out, Tensor out_grad, float threshold)
output : Tensor(x_grad)
kernel :
func : relu6_coo_grad {sparse_coo, sparse_coo -> sparse_coo},
relu6_csr_grad {sparse_csr, sparse_csr -> sparse_csr}

- backward_api : relu_grad
forward : relu(Tensor x) -> Tensor(out)
args : (Tensor out, Tensor out_grad)
@@ -255,7 +279,7 @@
- backward_api: fused_attention_grad
forward : fused_attention_csr(Tensor query, Tensor key, Tensor value, Tensor sparse_mask, Tensor key_padding_mask, Tensor attn_mask) -> Tensor(out), Tensor(softmax)
args: (Tensor query, Tensor key, Tensor value, Tensor softmax, Tensor out_grad)
output : Tensor(query_grad), Tensor(key_grad), Tensor(value_grad)
kernel :
func : fused_attention_csr_grad{dense, dense, dense, sparse_csr, dense -> dense, dense, dense}
layout : softmax
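A note on the argument choices above: expm1_grad and relu6_grad take out rather than x, while leaky_relu_grad takes x. That matches the standard derivatives, which for expm1 and relu6 are recoverable from the output alone, but for leaky_relu need the sign of the input:

.. math::

    \frac{d}{dx}\,expm1(x) = e^{x} = out + 1

    \frac{d}{dx}\,relu6(x) = \begin{cases} 1, & 0 < out < 6 \\ 0, & otherwise \end{cases}

    \frac{d}{dx}\,leaky\_relu(x) = \begin{cases} 1, & x \geq 0 \\ \alpha, & x < 0 \end{cases}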
2 changes: 1 addition & 1 deletion paddle/phi/kernels/activation_grad_kernel.h
@@ -240,11 +240,11 @@ DECLARE_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(Swish, beta);
DECLARE_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(Logit, eps);
DECLARE_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(Mish, threshold);
DECLARE_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(Celu, alpha);
DECLARE_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPOUT(Relu6, threshold);

DECLARE_ACT_GRAD_KERNEL_WITH_TWO_ATTRS_DEPX(BRelu, t_min, t_max);
DECLARE_ACT_GRAD_KERNEL_WITH_TWO_ATTRS_DEPX(STanh, scale_a, scale_b);
DECLARE_ACT_GRAD_KERNEL_WITH_TWO_ATTRS_DEPX(Softplus, beta, threshold);

DECLARE_ACT_GRAD_KERNEL_WITH_TWO_ATTRS_DEPOUT(HardSigmoid, slope, offset);

} // namespace phi
3 changes: 3 additions & 0 deletions paddle/phi/kernels/sparse/cpu/unary_grad_kernel.cc
@@ -51,6 +51,9 @@ PD_REGISTER_SPARSE_UNARY_CPU_GRAD_KERNEL(log1p, Log1p)
PD_REGISTER_SPARSE_UNARY_CPU_GRAD_KERNEL(relu, Relu)
PD_REGISTER_SPARSE_UNARY_CPU_GRAD_KERNEL(abs, Abs)
PD_REGISTER_SPARSE_UNARY_CPU_GRAD_KERNEL(pow, Pow)
PD_REGISTER_SPARSE_UNARY_CPU_GRAD_KERNEL(expm1, Expm1)
PD_REGISTER_SPARSE_UNARY_CPU_GRAD_KERNEL(relu6, Relu6)
PD_REGISTER_SPARSE_UNARY_CPU_GRAD_KERNEL(leaky_relu, LeakyRelu)

PD_REGISTER_KERNEL(cast_coo_grad,
CPU,
3 changes: 3 additions & 0 deletions paddle/phi/kernels/sparse/cpu/unary_kernel.cc
@@ -93,6 +93,9 @@ PD_REGISTER_SPARSE_UNARY_CPU_KERNEL(relu, Relu)
PD_REGISTER_SPARSE_UNARY_CPU_KERNEL(abs, Abs)
PD_REGISTER_SPARSE_UNARY_CPU_KERNEL(pow, Pow)
PD_REGISTER_SPARSE_UNARY_CPU_KERNEL(scale, Scale)
PD_REGISTER_SPARSE_UNARY_CPU_KERNEL(expm1, Expm1)
PD_REGISTER_SPARSE_UNARY_CPU_KERNEL(relu6, Relu6)
PD_REGISTER_SPARSE_UNARY_CPU_KERNEL(leaky_relu, LeakyRelu)

PD_REGISTER_KERNEL(divide_coo_scalar,
CPU,
3 changes: 3 additions & 0 deletions paddle/phi/kernels/sparse/gpu/unary_grad_kernel.cu
@@ -53,6 +53,9 @@ PD_REGISTER_SPARSE_UNARY_GPU_GRAD_KERNEL(log1p, Log1p)
PD_REGISTER_SPARSE_UNARY_GPU_GRAD_KERNEL(relu, Relu)
PD_REGISTER_SPARSE_UNARY_GPU_GRAD_KERNEL(abs, Abs)
PD_REGISTER_SPARSE_UNARY_GPU_GRAD_KERNEL(pow, Pow)
PD_REGISTER_SPARSE_UNARY_GPU_GRAD_KERNEL(expm1, Expm1)
PD_REGISTER_SPARSE_UNARY_GPU_GRAD_KERNEL(relu6, Relu6)
PD_REGISTER_SPARSE_UNARY_GPU_GRAD_KERNEL(leaky_relu, LeakyRelu)

PD_REGISTER_KERNEL(cast_coo_grad,
GPU,
3 changes: 3 additions & 0 deletions paddle/phi/kernels/sparse/gpu/unary_kernel.cu
@@ -98,6 +98,9 @@ PD_REGISTER_SPARSE_UNARY_GPU_KERNEL(relu, Relu)
PD_REGISTER_SPARSE_UNARY_GPU_KERNEL(abs, Abs)
PD_REGISTER_SPARSE_UNARY_GPU_KERNEL(pow, Pow)
PD_REGISTER_SPARSE_UNARY_GPU_KERNEL(scale, Scale)
PD_REGISTER_SPARSE_UNARY_GPU_KERNEL(expm1, Expm1)
PD_REGISTER_SPARSE_UNARY_GPU_KERNEL(relu6, Relu6)
PD_REGISTER_SPARSE_UNARY_GPU_KERNEL(leaky_relu, LeakyRelu)

PD_REGISTER_KERNEL(divide_coo_scalar,
GPU,
3 changes: 3 additions & 0 deletions paddle/phi/kernels/sparse/impl/unary_grad_kernel_impl.h
@@ -93,7 +93,10 @@ DEFINE_SPARSE_UNARY_GRAD_KERNEL(Square)
DEFINE_SPARSE_UNARY_GRAD_KERNEL(Log1p)
DEFINE_SPARSE_UNARY_GRAD_KERNEL(Relu)
DEFINE_SPARSE_UNARY_GRAD_KERNEL(Abs)
DEFINE_SPARSE_UNARY_GRAD_KERNEL(Expm1)
DEFINE_SPARSE_UNARY_GRAD_KERNEL_WITH_ONE_ATTR(Pow, factor)
DEFINE_SPARSE_UNARY_GRAD_KERNEL_WITH_ONE_ATTR(LeakyRelu, alpha)
DEFINE_SPARSE_UNARY_GRAD_KERNEL_WITH_ONE_ATTR(Relu6, threshold)

template <typename T, typename Context>
void CastCooGradKernel(const Context& dev_ctx,
3 changes: 3 additions & 0 deletions paddle/phi/kernels/sparse/impl/unary_kernel_impl.h
@@ -86,7 +86,10 @@ DEFINE_SPARSE_UNARY_KERNEL(Square)
DEFINE_SPARSE_UNARY_KERNEL(Log1p)
DEFINE_SPARSE_UNARY_KERNEL(Relu)
DEFINE_SPARSE_UNARY_KERNEL(Abs)
DEFINE_SPARSE_UNARY_KERNEL(Expm1)
DEFINE_SPARSE_UNARY_KERNEL_WITH_ONE_ATTR(Pow, factor)
DEFINE_SPARSE_UNARY_KERNEL_WITH_ONE_ATTR(Relu6, threshold)
DEFINE_SPARSE_UNARY_KERNEL_WITH_ONE_ATTR(LeakyRelu, alpha)

template <typename T, typename Context>
void ScaleCooKernel(const Context& dev_ctx,
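The DEFINE_SPARSE_UNARY_KERNEL* macro bodies are collapsed in this diff, but the contract they implement for COO input appears to be: apply the unary functor to the stored non-zero values and carry the indices through unchanged. A hypothetical Python reference of that contract (the helper name and the .indices()/.values() accessors are illustrative assumptions, not the real C++):

import paddle

def sparse_unary_reference(sparse_x, dense_op):
    # Hypothetical reference, not the actual kernel: transform the stored
    # values and keep the indices, so the sparsity pattern is unchanged.
    return paddle.incubate.sparse.sparse_coo_tensor(
        sparse_x.indices(), dense_op(sparse_x.values()), sparse_x.shape)

x = paddle.to_tensor([[0., -4.], [2., 0.]]).to_sparse_coo(2)
ref = sparse_unary_reference(
    x, lambda v: paddle.nn.functional.leaky_relu(v, 0.1))
out = paddle.incubate.sparse.nn.functional.leaky_relu(x, 0.1)
# ref.to_dense() and out.to_dense() should both equal [[0., -0.4], [2., 0.]]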
17 changes: 17 additions & 0 deletions python/paddle/fluid/tests/unittests/test_sparse_unary_op.py
@@ -123,9 +123,26 @@ def test_sparse_relu(self):
self.compare_with_dense(paddle.nn.ReLU(),
paddle.incubate.sparse.nn.ReLU())

def test_sparse_relu6(self):
self.compare_with_dense(paddle.nn.ReLU6(),
paddle.incubate.sparse.nn.ReLU6())

def test_sparse_leaky_relu(self):
self.compare_with_dense(paddle.nn.LeakyReLU(0.1),
paddle.incubate.sparse.nn.LeakyReLU(0.1))

def test_sparse_abs(self):
self.compare_with_dense(paddle.abs, paddle.incubate.sparse.abs)

def test_sparse_expm1(self):
self.compare_with_dense(paddle.expm1, paddle.incubate.sparse.expm1)

def test_sparse_deg2rad(self):
self.compare_with_dense(paddle.deg2rad, paddle.incubate.sparse.deg2rad)

def test_sparse_rad2deg(self):
self.compare_with_dense(paddle.rad2deg, paddle.incubate.sparse.rad2deg)

def test_sparse_neg(self):
self.compare_with_dense(paddle.neg, paddle.incubate.sparse.neg)

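compare_with_dense is defined earlier in test_sparse_unary_op.py and is collapsed here; a plausible reconstruction of what it checks (hypothetical, forward-only for brevity):

import numpy as np
import paddle

def compare_with_dense(dense_func, sparse_func):
    # Hypothetical sketch of the helper called by the new tests above:
    # run the dense and sparse ops on the same masked random data and
    # assert the results agree.
    mask = np.random.rand(3, 4) < 0.5
    np_x = (np.random.rand(3, 4) * mask).astype('float32')

    dense_out = dense_func(paddle.to_tensor(np_x))
    sparse_out = sparse_func(paddle.to_tensor(np_x).to_sparse_coo(2))

    np.testing.assert_allclose(dense_out.numpy(),
                               sparse_out.to_dense().numpy(),
                               rtol=1e-05)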
6 changes: 6 additions & 0 deletions python/paddle/incubate/sparse/__init__.py
@@ -31,6 +31,9 @@
from .unary import cast
from .unary import neg
from .unary import coalesce
from .unary import deg2rad
from .unary import rad2deg
from .unary import expm1

from .binary import mv
from .binary import matmul
@@ -60,6 +63,9 @@
'pow',
'cast',
'neg',
'deg2rad',
'rad2deg',
'expm1',
'mv',
'matmul',
'masked_matmul',
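deg2rad and rad2deg are imported from .unary above but add no entries to sparse_api.yaml; since both are pure scalings, they can presumably reuse the existing sparse scale kernel. A hypothetical sketch of that implementation (unary.py itself is collapsed in this diff; the scale signature follows the scale entry in sparse_api.yaml):

import math
from paddle import _C_ops

def deg2rad(x, name=None):
    # scale(x, scale, bias, bias_after_scale): pi/180 with zero bias
    # leaves the sparsity pattern intact.
    return _C_ops.final_state_sparse_scale(x, math.pi / 180.0, 0.0, True)

def rad2deg(x, name=None):
    return _C_ops.final_state_sparse_scale(x, 180.0 / math.pi, 0.0, True)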
4 changes: 4 additions & 0 deletions python/paddle/incubate/sparse/nn/__init__.py
@@ -16,13 +16,17 @@

from .layer.activation import ReLU
from .layer.activation import Softmax
from .layer.activation import ReLU6
from .layer.activation import LeakyReLU
from .layer.norm import BatchNorm
from .layer.conv import Conv3D
from .layer.conv import SubmConv3D
from .layer.pooling import MaxPool3D

__all__ = [
'ReLU',
'ReLU6',
'LeakyReLU',
'Softmax',
'BatchNorm',
'Conv3D',
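With the layer exports above and the functional exports in the next file, the new activations can be used either way; a small usage sketch (assuming the layers delegate to the functional forms, as in dense paddle.nn):

import paddle

sparse_x = paddle.to_tensor([-2., 0., 8.]).to_sparse_coo(1)

relu6 = paddle.incubate.sparse.nn.ReLU6()
out = relu6(sparse_x)
# dense view: [0., 0., 6.]

leaky = paddle.incubate.sparse.nn.LeakyReLU(0.1)
out = leaky(sparse_x)
# dense view: [-0.2, 0., 8.]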
4 changes: 4 additions & 0 deletions python/paddle/incubate/sparse/nn/functional/__init__.py
@@ -17,13 +17,17 @@
from .transformer import attention # noqa: F401
from .pooling import max_pool3d # noqa: F401
from .activation import relu # noqa: F401
from .activation import relu6 # noqa: F401
from .activation import leaky_relu # noqa: F401
from .activation import softmax # noqa: F401

__all__ = [
'conv3d',
'subm_conv3d',
'max_pool3d',
'relu',
'relu6',
'leaky_relu',
'softmax',
'attention',
]
119 changes: 90 additions & 29 deletions python/paddle/incubate/sparse/nn/functional/activation.py
@@ -21,7 +21,7 @@
@dygraph_only
def relu(x, name=None):
"""
- sparse relu activation, requiring x to be a sparse coo or sparse csr tensor.
+ sparse relu activation, requiring x to be a SparseCooTensor or SparseCsrTensor.

.. math::

@@ -39,20 +39,19 @@ def relu(x, name=None):
.. code-block:: python

import paddle
- from paddle.fluid.framework import _test_eager_guard

- with _test_eager_guard():
-     dense_x = paddle.to_tensor([-2, 0, 1], dtype='float32')
-     sparse_x = dense_x.to_sparse_coo(1)
-     out = paddle.incubate.sparse.nn.functional.relu(sparse_x)
+ dense_x = paddle.to_tensor([-2., 0., 1.])
+ sparse_x = dense_x.to_sparse_coo(1)
+ out = paddle.incubate.sparse.nn.functional.relu(sparse_x)
+ # [0., 0., 1.]
"""
return _C_ops.final_state_sparse_relu(x)


@dygraph_only
def softmax(x, axis=-1, name=None):
"""
- sparse softmax activation, x must be SparseCsrTensor or SparseCooTensor.
+ sparse softmax activation, requiring x to be a SparseCooTensor or SparseCsrTensor.

Note:
Only support axis=-1 for SparseCsrTensor, which is faster when read data
@@ -79,30 +78,92 @@

import paddle
import numpy as np
- from paddle.fluid.framework import _test_eager_guard

paddle.seed(100)

- with _test_eager_guard():
-     mask = np.random.rand(3, 4) < 0.5
-     np_x = np.random.rand(3, 4) * mask
-     # [[0. 0. 0.96823406 0.19722934]
-     #  [0.94373937 0. 0.02060066 0.71456372]
-     #  [0. 0. 0. 0.98275049]]
-
-     csr = paddle.to_tensor(np_x).to_sparse_csr()
-     # Tensor(shape=[3, 4], dtype=paddle.float64, place=Place(gpu:0), stop_gradient=True,
-     #        crows=[0, 2, 5, 6],
-     #        cols=[2, 3, 0, 2, 3, 3],
-     #        values=[0.96823406, 0.19722934, 0.94373937, 0.02060066, 0.71456372,
-     #                0.98275049])
-
-     out = paddle.incubate.sparse.nn.functional.softmax(csr)
-     # Tensor(shape=[3, 4], dtype=paddle.float64, place=Place(gpu:0), stop_gradient=True,
-     #        crows=[0, 2, 5, 6],
-     #        cols=[2, 3, 0, 2, 3, 3],
-     #        values=[0.68373820, 0.31626180, 0.45610887, 0.18119845, 0.36269269,
-     #                1.        ])
+ mask = np.random.rand(3, 4) < 0.5
+ np_x = np.random.rand(3, 4) * mask
+ # [[0. 0. 0.96823406 0.19722934]
+ #  [0.94373937 0. 0.02060066 0.71456372]
+ #  [0. 0. 0. 0.98275049]]
+
+ csr = paddle.to_tensor(np_x).to_sparse_csr()
+ # Tensor(shape=[3, 4], dtype=paddle.float64, place=Place(gpu:0), stop_gradient=True,
+ #        crows=[0, 2, 5, 6],
+ #        cols=[2, 3, 0, 2, 3, 3],
+ #        values=[0.96823406, 0.19722934, 0.94373937, 0.02060066, 0.71456372,
+ #                0.98275049])
+
+ out = paddle.incubate.sparse.nn.functional.softmax(csr)
+ # Tensor(shape=[3, 4], dtype=paddle.float64, place=Place(gpu:0), stop_gradient=True,
+ #        crows=[0, 2, 5, 6],
+ #        cols=[2, 3, 0, 2, 3, 3],
+ #        values=[0.68373820, 0.31626180, 0.45610887, 0.18119845, 0.36269269,
+ #                1.        ])

"""
return _C_ops.final_state_sparse_softmax(x, axis)


@dygraph_only
def relu6(x, name=None):
"""
sparse relu6 activation, requiring x to be a SparseCooTensor or SparseCsrTensor.

.. math::

relu6(x) = min(max(0, x), 6)

Parameters:
x (Tensor): The input Sparse Tensor with data type float32, float64.
name (str, optional): Name for the operation (optional, default is None).
For more information, please refer to :ref:`api_guide_Name`.

Returns:
A Sparse Tensor with the same data type and shape as ``x`` .

Examples:
.. code-block:: python

import paddle

dense_x = paddle.to_tensor([-2., 0., 8.])
sparse_x = dense_x.to_sparse_coo(1)
out = paddle.incubate.sparse.nn.functional.relu6(sparse_x)
# [0., 0., 6.]
"""
return _C_ops.final_state_sparse_relu6(x, 6.0)


@dygraph_only
def leaky_relu(x, negative_slope=0.01, name=None):
"""
sparse leaky_relu activation, requiring x to be a SparseCooTensor or SparseCsrTensor.

.. math::
leaky\_relu(x)=
\left\{
\begin{array}{rcl}
x, & & if \ x >= 0 \\
negative\_slope * x, & & otherwise \\
\end{array}
\right.

Parameters:
x (Tensor): The input Sparse Tensor with data type float32, float64.
negative_slope (float, optional): Slope of the activation function at
:math:`x < 0` . Default is 0.01.
name (str, optional): Name for the operation (optional, default is None).
For more information, please refer to :ref:`api_guide_Name`.

Returns:
A Sparse Tensor with the same data type and shape as ``x`` .

Examples:
.. code-block:: python

import paddle

dense_x = paddle.to_tensor([-2., 0., 5.])
sparse_x = dense_x.to_sparse_coo(1)
out = paddle.incubate.sparse.nn.functional.leaky_relu(sparse_x, 0.5)
# [-1., 0., 5.]
"""
return _C_ops.final_state_sparse_leaky_relu(x, negative_slope)