Add sparse SyncBatchNorm #43520

Merged (9 commits, Jul 29, 2022)
39 changes: 34 additions & 5 deletions python/paddle/fluid/tests/unittests/test_sparse_norm_op.py
@@ -16,6 +16,7 @@
import unittest
import numpy as np
import paddle
from paddle.incubate.sparse import nn
import paddle.fluid as fluid
from paddle.fluid.framework import _test_eager_guard
import copy
@@ -56,11 +57,10 @@ def test(self):

# test backward
sparse_y.backward(sparse_y)
assert np.allclose(
dense_x.grad.flatten().numpy(),
sparse_x.grad.values().flatten().numpy(),
atol=1e-5,
rtol=1e-5)
assert np.allclose(dense_x.grad.flatten().numpy(),
sparse_x.grad.values().flatten().numpy(),
atol=1e-5,
rtol=1e-5)
fluid.set_flags({"FLAGS_retain_grad_for_all_tensor": False})

def test_error_layout(self):
@@ -86,5 +86,34 @@ def test2(self):
# [1, 6, 6, 6, 3]


class TestSyncBatchNorm(unittest.TestCase):

def test_sync_batch_norm(self):
with _test_eager_guard():
Contributor: Is this still needed?

Contributor (author): OK, I'll remove it in the next PR.

x = np.array([[[[0.3, 0.4], [0.3, 0.07]],
[[0.83, 0.37], [0.18, 0.93]]]]).astype('float32')
x = paddle.to_tensor(x)
x = x.to_sparse_coo(len(x.shape) - 1)

if paddle.is_compiled_with_cuda():
sync_batch_norm = nn.SyncBatchNorm(2)
hidden1 = sync_batch_norm(x)
print(hidden1)
Contributor: This can be removed now.

Contributor (author): OK, I'll remove it in the next PR.


def test_convert(self):
base_model = paddle.nn.Sequential(nn.Conv3D(3, 5, 3), nn.BatchNorm(5),
nn.BatchNorm(5))

model = paddle.nn.Sequential(
nn.Conv3D(3, 5, 3), nn.BatchNorm(5),
nn.BatchNorm(5,
weight_attr=fluid.ParamAttr(name='bn.scale'),
bias_attr=fluid.ParamAttr(name='bn.bias')))
model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
for idx, sublayer in enumerate(base_model.sublayers()):
if isinstance(sublayer, nn.BatchNorm):
self.assertEqual(isinstance(model[idx], nn.SyncBatchNorm), True)


if __name__ == "__main__":
unittest.main()
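A minimal single-GPU sketch of the new layer follows, assuming a Paddle build that already includes this PR and a CUDA device; it is not part of the diff. It checks that the sparse wrapper produces the same result as a dense batch norm applied to the tensor's non-zero values, which is what the forward pass in norm.py reduces to on a single card.

# Sketch only: assumes paddle.incubate.sparse.nn.SyncBatchNorm from this PR
# and a CUDA build; defaults (momentum=0.9, epsilon=1e-5) match BatchNorm1D.
import numpy as np
import paddle
import paddle.incubate.sparse.nn as spnn

if paddle.is_compiled_with_cuda():
    dense = np.array([[[[0.3, 0.4], [0.3, 0.07]],
                       [[0.83, 0.37], [0.18, 0.93]]]]).astype('float32')
    # Keep the last (channel) dimension dense: values() has shape [nnz, 2].
    x = paddle.to_tensor(dense).to_sparse_coo(3)

    sp_bn = spnn.SyncBatchNorm(2)      # sparse wrapper added in this PR
    dn_bn = paddle.nn.BatchNorm1D(2)   # dense reference over [nnz, channels]

    sp_out = sp_bn(x)                  # sparse COO output, same sparsity pattern
    dn_out = dn_bn(x.values())

    # On one card the synchronized statistics reduce to ordinary batch
    # statistics, so the two results should agree within tolerance.
    assert np.allclose(sp_out.values().numpy(), dn_out.numpy(), atol=1e-5)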
3 changes: 2 additions & 1 deletion python/paddle/incubate/sparse/nn/__init__.py
@@ -15,10 +15,10 @@
from . import functional

from .layer.activation import ReLU
from .layer.norm import BatchNorm, SyncBatchNorm
from .layer.activation import Softmax
from .layer.activation import ReLU6
from .layer.activation import LeakyReLU
from .layer.norm import BatchNorm
from .layer.conv import Conv3D
from .layer.conv import SubmConv3D
from .layer.pooling import MaxPool3D
@@ -29,6 +29,7 @@
'LeakyReLU',
'Softmax',
'BatchNorm',
'SyncBatchNorm',
'Conv3D',
'SubmConv3D',
'MaxPool3D',
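Since the only change in this file is the public export, a quick smoke check (again assuming a build that contains this PR) is simply:

# Confirms the new symbol is exposed from the incubate namespace.
import paddle.incubate.sparse.nn as spnn

assert hasattr(spnn, 'SyncBatchNorm')
assert 'SyncBatchNorm' in spnn.__all__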
176 changes: 176 additions & 0 deletions python/paddle/incubate/sparse/nn/layer/norm.py
@@ -27,6 +27,8 @@

import paddle
import warnings
from paddle.nn.layer.norm import _BatchNormBase
from paddle.framework import no_grad


class BatchNorm(paddle.nn.BatchNorm1D):
@@ -157,3 +159,177 @@ def forward(self, input):
batch_norm_out,
shape=input.shape,
stop_gradient=input.stop_gradient)


class SyncBatchNorm(paddle.nn.SyncBatchNorm):
r"""
This interface is used to construct a callable object of the ``SyncBatchNorm`` class.
It implements the function of the Cross-GPU Synchronized Batch Normalization Layer, and can
be used as a normalizer function for other operations, such as conv2d and fully connected
operations.
The data is normalized by the per-channel mean and variance computed over the whole
mini-batch, which includes the data on all GPUs.
Refer to `Batch Normalization: Accelerating Deep Network Training by Reducing
Internal Covariate Shift <https://arxiv.org/pdf/1502.03167.pdf>`_
for more details.

When the model is in training mode, :math:`\mu_{\beta}`
and :math:`\sigma_{\beta}^{2}` are the statistics of the whole mini-batch data across all GPUs,
calculated as follows:

.. math::

\mu_{\beta} &\gets \frac{1}{m} \sum_{i=1}^{m} x_i \qquad &//\
\ mini-batch\ mean \\
\sigma_{\beta}^{2} &\gets \frac{1}{m} \sum_{i=1}^{m}(x_i - \
\mu_{\beta})^2 \qquad &//\ mini-batch\ variance \\

- :math:`x` : whole mini-batch data in all gpus
- :math:`m` : the size of the whole mini-batch data

When the model is in evaluation mode, :math:`\mu_{\beta}`
and :math:`\sigma_{\beta}^{2}` are global statistics (moving_mean and moving_variance,
usually obtained from a pre-trained model). The global statistics are calculated as follows:

.. math::
moving\_mean = moving\_mean * momentum + \mu_{\beta} * (1. - momentum) \quad &// global \ mean \\
moving\_variance = moving\_variance * momentum + \sigma_{\beta}^{2} * (1. - momentum) \quad &// global \ variance \\

The formula of normalization is as follows:

.. math::

\hat{x_i} &\gets \frac{x_i - \mu_\beta} {\sqrt{\
\sigma_{\beta}^{2} + \epsilon}} \qquad &//\ normalize \\
y_i &\gets \gamma \hat{x_i} + \beta \qquad &//\ scale\ and\ shift

- :math:`\epsilon` : a small value added to the variance to prevent division by zero
- :math:`\gamma` : trainable scale parameter vector
- :math:`\beta` : trainable shift parameter vector

Note:
If you want to use a container to pack your model and the model contains ``SyncBatchNorm``
layers used in the evaluation phase, please use ``nn.LayerList`` or ``nn.Sequential``
instead of a plain ``list`` to pack the model.

Parameters:
num_features(int): Indicate the number of channels of the input ``Tensor``.
epsilon(float, optional): The small value added to the variance to prevent division by zero. Default: 1e-5.
momentum(float, optional): The value used for the moving_mean and moving_var computation. Default: 0.9.
weight_attr(ParamAttr|bool, optional): The parameter attribute for the Parameter `scale`
of this layer. If it is set to None or one attribute of ParamAttr, this layer
will create ParamAttr as weight_attr. If the Initializer of the weight_attr
is not set, the parameter is initialized with Xavier. If it is set to False,
this layer will not have a trainable scale parameter. Default: None.
bias_attr(ParamAttr|bool, optional): The parameter attribute for the bias of this layer.
If it is set to None or one attribute of ParamAttr, this layer
will create ParamAttr as bias_attr. If the Initializer of the bias_attr
is not set, the bias is initialized to zero. If it is set to False, this layer will not
have a trainable bias parameter. Default: None.

Shapes:
input: Tensor whose number of dimensions is between 2 and 5.
output: Tensor with the same shape as input.
Contributor: A blank line could be added between the input and output lines here.

Contributor (author): OK, I'll fix that in the next PR.


Examples:
.. code-block:: python

# required: gpu
import paddle
import paddle.incubate.sparse.nn as nn
import numpy as np

x = np.array([[[[0.3, 0.4], [0.3, 0.07]], [[0.83, 0.37], [0.18, 0.93]]]]).astype('float32')
x = paddle.to_tensor(x)
x = x.to_sparse_coo(len(x.shape)-1)

if paddle.is_compiled_with_cuda():
sync_batch_norm = nn.SyncBatchNorm(2)
hidden1 = sync_batch_norm(x)
print(hidden1)
# Tensor(shape=[1, 2, 2, 2], dtype=paddle.float32, place=Place(gpu:0), stop_gradient=True,
# indices=[[0, 0, 0, 0],
# [0, 0, 1, 1],
# [0, 1, 0, 1]],
# values=[[-0.40730840, -0.13725480],
# [-0.40730840, -1.20299828],
# [ 1.69877410, -0.23414057],
# [-0.88415730, 1.57439375]])
"""

def __init__(self,
num_features,
momentum=0.9,
epsilon=1e-05,
weight_attr=None,
bias_attr=None,
data_format='NCHW',
name=None):
super(SyncBatchNorm,
self).__init__(num_features, momentum, epsilon, weight_attr,
bias_attr, data_format, name)

def forward(self, x):
assert x.is_sparse_coo(
), "SyncBatchNorm only supports SparseTensor in COO format."
out = super(SyncBatchNorm, self).forward(x.values())
return paddle.incubate.sparse.sparse_coo_tensor(
x.indices(), out, shape=x.shape, stop_gradient=x.stop_gradient)

@classmethod
def convert_sync_batchnorm(cls, layer):
"""
Helper function to convert :class:`paddle.incubate.sparse.nn.BatchNorm` layers in the model to :class:`paddle.incubate.sparse.nn.SyncBatchNorm` layers.

Parameters:
layer(paddle.nn.Layer): model containing one or more `BatchNorm` layers.

Returns:
The original model with any BatchNorm layers replaced by the corresponding SyncBatchNorm layers.

Examples:

.. code-block:: python

import paddle
import paddle.incubate.sparse.nn as nn

model = paddle.nn.Sequential(nn.Conv3D(3, 5, 3), nn.BatchNorm(5))
sync_model = nn.SyncBatchNorm.convert_sync_batchnorm(model)

"""
layer_output = layer
if isinstance(layer, _BatchNormBase):
if layer._weight_attr != None and not isinstance(
layer._weight_attr,
bool) and layer._weight_attr.name != None:
layer._weight_attr.name = layer._weight_attr.name + '_sync'
if layer._bias_attr != None and not isinstance(
layer._bias_attr, bool) and layer._bias_attr.name != None:
layer._bias_attr.name = layer._bias_attr.name + '_sync'

#convert sparse BatchNorm
if isinstance(layer, BatchNorm):
layer_output = SyncBatchNorm(layer._num_features,
layer._momentum, layer._epsilon,
layer._weight_attr,
layer._bias_attr,
layer._data_format, layer._name)
#convert dense BatchNorm
else:
layer_output = paddle.nn.SyncBatchNorm(
layer._num_features, layer._momentum, layer._epsilon,
layer._weight_attr, layer._bias_attr, layer._data_format,
layer._name)

if layer._weight_attr != False and layer._bias_attr != False:
with no_grad():
layer_output.weight = layer.weight
layer_output.bias = layer.bias
layer_output._mean = layer._mean
layer_output._variance = layer._variance

for name, sublayer in layer.named_children():
layer_output.add_sublayer(name,
cls.convert_sync_batchnorm(sublayer))
del layer
return layer_output
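For context, a hedged sketch of how the conversion helper would typically be used in a multi-card run is shown below; the launcher command in the comment is an assumption and nothing here is part of the diff.

# Sketch only: intended to be started via a distributed launcher, e.g.
# `python -m paddle.distributed.launch --gpus 0,1 train.py` (hypothetical
# script name), with a Paddle build that contains this PR.
import paddle
import paddle.incubate.sparse.nn as spnn

# Only initialize the parallel environment when more than one card is used.
if paddle.distributed.get_world_size() > 1:
    paddle.distributed.init_parallel_env()

# A model mixing sparse conv and sparse batch norm, as in test_convert().
model = paddle.nn.Sequential(spnn.Conv3D(3, 5, 3), spnn.BatchNorm(5))

# Every spnn.BatchNorm becomes spnn.SyncBatchNorm; a dense paddle.nn.BatchNorm*
# inside the same container would become paddle.nn.SyncBatchNorm instead.
model = spnn.SyncBatchNorm.convert_sync_batchnorm(model)

# Training then proceeds as usual; each converted layer reduces its batch
# statistics across all participating GPUs during the forward pass.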