diff --git a/python/paddle/fluid/tests/unittests/test_sparse_norm_op.py b/python/paddle/fluid/tests/unittests/test_sparse_norm_op.py index 8eccefed6ef64..31d3e380c5196 100644 --- a/python/paddle/fluid/tests/unittests/test_sparse_norm_op.py +++ b/python/paddle/fluid/tests/unittests/test_sparse_norm_op.py @@ -16,6 +16,7 @@ import unittest import numpy as np import paddle +from paddle.incubate.sparse import nn import paddle.fluid as fluid from paddle.fluid.framework import _test_eager_guard import copy @@ -56,11 +57,10 @@ def test(self): # test backward sparse_y.backward(sparse_y) - assert np.allclose( - dense_x.grad.flatten().numpy(), - sparse_x.grad.values().flatten().numpy(), - atol=1e-5, - rtol=1e-5) + assert np.allclose(dense_x.grad.flatten().numpy(), + sparse_x.grad.values().flatten().numpy(), + atol=1e-5, + rtol=1e-5) fluid.set_flags({"FLAGS_retain_grad_for_all_tensor": False}) def test_error_layout(self): @@ -86,5 +86,34 @@ def test2(self): # [1, 6, 6, 6, 3] +class TestSyncBatchNorm(unittest.TestCase): + + def test_sync_batch_norm(self): + with _test_eager_guard(): + x = np.array([[[[0.3, 0.4], [0.3, 0.07]], + [[0.83, 0.37], [0.18, 0.93]]]]).astype('float32') + x = paddle.to_tensor(x) + x = x.to_sparse_coo(len(x.shape) - 1) + + if paddle.is_compiled_with_cuda(): + sync_batch_norm = nn.SyncBatchNorm(2) + hidden1 = sync_batch_norm(x) + print(hidden1) + + def test_convert(self): + base_model = paddle.nn.Sequential(nn.Conv3D(3, 5, 3), nn.BatchNorm(5), + nn.BatchNorm(5)) + + model = paddle.nn.Sequential( + nn.Conv3D(3, 5, 3), nn.BatchNorm(5), + nn.BatchNorm(5, + weight_attr=fluid.ParamAttr(name='bn.scale'), + bias_attr=fluid.ParamAttr(name='bn.bias'))) + model = nn.SyncBatchNorm.convert_sync_batchnorm(model) + for idx, sublayer in enumerate(base_model.sublayers()): + if isinstance(sublayer, nn.BatchNorm): + self.assertEqual(isinstance(model[idx], nn.SyncBatchNorm), True) + + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/incubate/sparse/nn/__init__.py b/python/paddle/incubate/sparse/nn/__init__.py index 1d5889753716a..bb4fa18877be4 100644 --- a/python/paddle/incubate/sparse/nn/__init__.py +++ b/python/paddle/incubate/sparse/nn/__init__.py @@ -15,10 +15,10 @@ from . import functional from .layer.activation import ReLU +from .layer.norm import BatchNorm, SyncBatchNorm from .layer.activation import Softmax from .layer.activation import ReLU6 from .layer.activation import LeakyReLU -from .layer.norm import BatchNorm from .layer.conv import Conv3D from .layer.conv import SubmConv3D from .layer.pooling import MaxPool3D @@ -29,6 +29,7 @@ 'LeakyReLU', 'Softmax', 'BatchNorm', + 'SyncBatchNorm', 'Conv3D', 'SubmConv3D', 'MaxPool3D', diff --git a/python/paddle/incubate/sparse/nn/layer/norm.py b/python/paddle/incubate/sparse/nn/layer/norm.py index 4d4cf7df2f2e4..2dbefcd4bfedc 100644 --- a/python/paddle/incubate/sparse/nn/layer/norm.py +++ b/python/paddle/incubate/sparse/nn/layer/norm.py @@ -27,6 +27,8 @@ import paddle import warnings +from paddle.nn.layer.norm import _BatchNormBase +from paddle.framework import no_grad class BatchNorm(paddle.nn.BatchNorm1D): @@ -157,3 +159,177 @@ def forward(self, input): batch_norm_out, shape=input.shape, stop_gradient=input.stop_gradient) + + +class SyncBatchNorm(paddle.nn.SyncBatchNorm): + r""" + This interface is used to construct a callable object of the ``SyncBatchNorm`` class. + It implements the function of the Cross-GPU Synchronized Batch Normalization Layer, and can + be used as a normalizer function for other operations, such as conv2d and fully connected + operations. + The data is normalized by the mean and variance of the channel based on whole mini-batch + , which including data in all gpus. + Refer to `Batch Normalization: Accelerating Deep Network Training by Reducing + Internal Covariate Shift `_ + for more details. + + When model in training mode, the :math:`\\mu_{\\beta}` + and :math:`\\sigma_{\\beta}^{2}` are the statistics of whole mini-batch data in all gpus. + Calculated as follows: + + .. math:: + + \mu_{\beta} &\gets \frac{1}{m} \sum_{i=1}^{m} x_i \qquad &//\ + \ mini-batch\ mean \\ + \sigma_{\beta}^{2} &\gets \frac{1}{m} \sum_{i=1}^{m}(x_i - \ + \mu_{\beta})^2 \qquad &//\ mini-batch\ variance \\ + + - :math:`x` : whole mini-batch data in all gpus + - :math:`m` : the size of the whole mini-batch data + + When model in evaluation mode, the :math:`\\mu_{\\beta}` + and :math:`\sigma_{\beta}^{2}` are global statistics (moving_mean and moving_variance, + which usually got from the pre-trained model). Global statistics calculated as follows: + + .. math:: + moving\_mean = moving\_mean * momentum + \mu_{\beta} * (1. - momentum) \quad &// global \ mean \\ + moving\_variance = moving\_variance * momentum + \sigma_{\beta}^{2} * (1. - momentum) \quad &// global \ variance \\ + + The formula of normalization is as follows: + + .. math:: + + \hat{x_i} &\gets \frac{x_i - \mu_\beta} {\sqrt{\ + \sigma_{\beta}^{2} + \epsilon}} \qquad &//\ normalize \\ + y_i &\gets \gamma \hat{x_i} + \beta \qquad &//\ scale\ and\ shift + + - :math:`\epsilon` : add a smaller value to the variance to prevent division by zero + - :math:`\gamma` : trainable scale parameter vector + - :math:`\beta` : trainable shift parameter vector + + Note: + If you want to use container to pack your model and has ``SyncBatchNorm`` in the + evaluation phase, please use ``nn.LayerList`` or ``nn.Sequential`` instead of + ``list`` to pack the model. + + Parameters: + num_features(int): Indicate the number of channels of the input ``Tensor``. + epsilon(float, optional): The small value added to the variance to prevent division by zero. Default: 1e-5. + momentum(float, optional): The value used for the moving_mean and moving_var computation. Default: 0.9. + weight_attr(ParamAttr|bool, optional): The parameter attribute for Parameter `scale` + of this layer. If it is set to None or one attribute of ParamAttr, this layerr + will create ParamAttr as param_attr. If the Initializer of the param_attr + is not set, the parameter is initialized with Xavier. If it is set to False, + this layer will not have trainable scale parameter. Default: None. + bias_attr(ParamAttr|bool, optional): The parameter attribute for the bias of this layer. + If it is set to None or one attribute of ParamAttr, this layer + will create ParamAttr as bias_attr. If the Initializer of the bias_attr + is not set, the bias is initialized zero. If it is set to False, this layer will not + have trainable bias parameter. Default: None. + + Shapes: + input: Tensor that the dimension from 2 to 5. + output: Tensor with the same shape as input. + + Examples: + .. code-block:: python + + # required: gpu + import paddle + import paddle.incubate.sparse.nn as nn + import numpy as np + + x = np.array([[[[0.3, 0.4], [0.3, 0.07]], [[0.83, 0.37], [0.18, 0.93]]]]).astype('float32') + x = paddle.to_tensor(x) + x = x.to_sparse_coo(len(x.shape)-1) + + if paddle.is_compiled_with_cuda(): + sync_batch_norm = nn.SyncBatchNorm(2) + hidden1 = sync_batch_norm(x) + print(hidden1) + # Tensor(shape=[1, 2, 2, 2], dtype=paddle.float32, place=Place(gpu:0), stop_gradient=True, + # indices=[[0, 0, 0, 0], + # [0, 0, 1, 1], + # [0, 1, 0, 1]], + # values=[[-0.40730840, -0.13725480], + # [-0.40730840, -1.20299828], + # [ 1.69877410, -0.23414057], + # [-0.88415730, 1.57439375]]) + """ + + def __init__(self, + num_features, + momentum=0.9, + epsilon=1e-05, + weight_attr=None, + bias_attr=None, + data_format='NCHW', + name=None): + super(SyncBatchNorm, + self).__init__(num_features, momentum, epsilon, weight_attr, + bias_attr, data_format, name) + + def forward(self, x): + assert x.is_sparse_coo( + ), "SyncBatchNorm only support SparseTensor in COO format." + out = super(SyncBatchNorm, self).forward(x.values()) + return paddle.incubate.sparse.sparse_coo_tensor( + x.indices(), out, shape=x.shape, stop_gradient=x.stop_gradient) + + @classmethod + def convert_sync_batchnorm(cls, layer): + """ + Helper function to convert :class: `paddle.incubate.sparse.nn.BatchNorm` layers in the model to :class: `paddle.incubate.sparse.nn.SyncBatchNorm` layers. + + Parameters: + layer(paddle.nn.Layer): model containing one or more `BatchNorm` layers. + + Returns: + The original model with converted SyncBatchNorm layers. If BatchNorm layer in the model, use SyncBatchNorm layer instead. + + Examples: + + .. code-block:: python + import paddle + import paddle.incubate.sparse.nn as nn + + model = paddle.nn.Sequential(nn.Conv3D(3, 5, 3), nn.BatchNorm(5)) + sync_model = nn.SyncBatchNorm.convert_sync_batchnorm(model) + + """ + layer_output = layer + if isinstance(layer, _BatchNormBase): + if layer._weight_attr != None and not isinstance( + layer._weight_attr, + bool) and layer._weight_attr.name != None: + layer._weight_attr.name = layer._weight_attr.name + '_sync' + if layer._bias_attr != None and not isinstance( + layer._bias_attr, bool) and layer._bias_attr.name != None: + layer._bias_attr.name = layer._bias_attr.name + '_sync' + + #convert sparse BatchNorm + if isinstance(layer, BatchNorm): + layer_output = SyncBatchNorm(layer._num_features, + layer._momentum, layer._epsilon, + layer._weight_attr, + layer._bias_attr, + layer._data_format, layer._name) + #convert dense BatchNorm + else: + layer_output = paddle.nn.SyncBatchNorm( + layer._num_features, layer._momentum, layer._epsilon, + layer._weight_attr, layer._bias_attr, layer._data_format, + layer._name) + + if layer._weight_attr != False and layer._bias_attr != False: + with no_grad(): + layer_output.weight = layer.weight + layer_output.bias = layer.bias + layer_output._mean = layer._mean + layer_output._variance = layer._variance + + for name, sublayer in layer.named_children(): + layer_output.add_sublayer(name, + cls.convert_sync_batchnorm(sublayer)) + del layer + return layer_output