diff --git a/paddle/phi/kernels/sparse/cpu/coalesced_kernel.cc b/paddle/phi/kernels/sparse/cpu/coalesced_kernel.cc index 0ebddf9b683f0..22c5e14b35f56 100644 --- a/paddle/phi/kernels/sparse/cpu/coalesced_kernel.cc +++ b/paddle/phi/kernels/sparse/cpu/coalesced_kernel.cc @@ -44,7 +44,7 @@ void CoalescedCPUKernel(const CPUContext& dev_ctx, const T* x_values_ptr = x_values.data(); const int64_t stride = - x.dims().size() == sparse_dim ? 1 : x.dims().size() - sparse_dim; + x.dims().size() == sparse_dim ? 1 : x.non_zero_elements().dims()[1]; std::map> indices_to_index; for (uint64_t i = 0; i < x_indexs.size(); i++) { diff --git a/paddle/phi/kernels/sparse/cpu/sparse_mask_kernel.cc b/paddle/phi/kernels/sparse/cpu/sparse_mask_kernel.cc index 1508de407caa7..0ec8b808ba838 100644 --- a/paddle/phi/kernels/sparse/cpu/sparse_mask_kernel.cc +++ b/paddle/phi/kernels/sparse/cpu/sparse_mask_kernel.cc @@ -125,7 +125,7 @@ void SparseMaskHelperCPUKernel(const CPUContext& dev_ctx, T* out_ptr = out->data(); memset(out_ptr, static_cast(0), out->numel() * sizeof(T)); const int64_t stride = - x.dims().size() == sparse_dim ? 1 : x.dims().size() - sparse_dim; + x.dims().size() == sparse_dim ? 1 : x.non_zero_elements().dims()[1]; const T* in_ptr = x.non_zero_elements().data(); // TODO(zhangkaihuo): multithreading can be used for acceleration for (uint64_t i = 0; i < mask_indexs.size(); i++) { diff --git a/paddle/phi/kernels/sparse/gpu/coalesced_kernel.cu b/paddle/phi/kernels/sparse/gpu/coalesced_kernel.cu index 3ffcd28955a53..b2e7884580c74 100644 --- a/paddle/phi/kernels/sparse/gpu/coalesced_kernel.cu +++ b/paddle/phi/kernels/sparse/gpu/coalesced_kernel.cu @@ -76,7 +76,7 @@ void CoalescedGPUKernel(const GPUContext& dev_ctx, // 2. get the address of each non-zero values const T* x_values_ptr = x_values.data(); const int64_t stride = - x.dims().size() == sparse_dim ? 1 : x.dims().size() - sparse_dim; + x.dims().size() == sparse_dim ? 1 : x.non_zero_elements().dims()[1]; DenseTensor values_indexs = phi::Empty( dev_ctx, DenseTensorMeta(DataType::INT32, {nnz}, DataLayout::NCHW)); int* values_indexs_ptr = values_indexs.data(); diff --git a/paddle/phi/kernels/sparse/gpu/sparse_mask_kernel.cu b/paddle/phi/kernels/sparse/gpu/sparse_mask_kernel.cu index 4e2d12f33955e..4253845956ea7 100644 --- a/paddle/phi/kernels/sparse/gpu/sparse_mask_kernel.cu +++ b/paddle/phi/kernels/sparse/gpu/sparse_mask_kernel.cu @@ -231,7 +231,7 @@ void SparseMaskHelperGPUKernel(const GPUContext& dev_ctx, T* out_ptr = out->data(); const int64_t stride = - x.dims().size() == sparse_dim ? 1 : x.dims().size() - sparse_dim; + x.dims().size() == sparse_dim ? 1 : x.non_zero_elements().dims()[1]; SparseMaskCopyKernel<< 1: + lens = np.append(lens, values.shape[1:]) + return list(lens) def _get_place(place): @@ -106,7 +111,7 @@ def sparse_coo_tensor(indices, with _test_eager_guard(): indices = [[0, 1, 2], [1, 2, 0]] values = [1.0, 2.0, 3.0] - dense_shape = [2, 3] + dense_shape = [3, 3] coo = paddle.sparse.sparse_coo_tensor(indices, values, dense_shape) # print(coo) # Tensor(shape=[2, 3], dtype=paddle.float32, place=Place(gpu:0), stop_gradient=True, @@ -145,7 +150,8 @@ def sparse_coo_tensor(indices, values = _handle_dtype(values, dtype) values.stop_gradient = stop_gradient - min_shape = _infer_dense_shape(indices) + min_shape = _infer_dense_shape(indices, values) + if shape is None: shape = min_shape else: diff --git a/python/paddle/sparse/functional/conv.py b/python/paddle/sparse/functional/conv.py index d8c0e5c914ccb..42b7b49835cf0 100644 --- a/python/paddle/sparse/functional/conv.py +++ b/python/paddle/sparse/functional/conv.py @@ -16,6 +16,8 @@ from paddle import _C_ops, in_dynamic_mode from ...fluid.layers.utils import convert_to_list +from ...fluid.layers.nn import elementwise_add +from .. import sparse_coo_tensor from paddle.nn.functional.conv import _update_padding_nd @@ -30,7 +32,6 @@ def _conv3d(x, data_format="NDHWC", name=None): assert in_dynamic_mode(), "Currently, only support dynamic mode" - assert bias == None, "Currently, sparse_conv3d does not support bias" assert groups == 1, "Currently, only support groups=1" dims = 3 @@ -61,8 +62,18 @@ def _conv3d(x, dilation = convert_to_list(dilation, dims, 'dilation') op_type = "conv3d" - return _C_ops.final_state_sparse_conv3d(x, weight, padding, dilation, - stride, groups, subm) + pre_bias = _C_ops.final_state_sparse_conv3d(x, weight, padding, dilation, + stride, groups, subm) + if bias is not None: + values = pre_bias.values() + add_bias = elementwise_add(values, bias, axis=1) + return sparse_coo_tensor( + pre_bias.indices(), + add_bias, + shape=pre_bias.shape, + stop_gradient=pre_bias.stop_gradient) + else: + return pre_bias def conv3d(x, diff --git a/python/paddle/sparse/layer/__init__.py b/python/paddle/sparse/layer/__init__.py index a0f9d068e677c..ee32e5027b50f 100644 --- a/python/paddle/sparse/layer/__init__.py +++ b/python/paddle/sparse/layer/__init__.py @@ -13,6 +13,7 @@ # limitations under the License. from .activation import ReLU +from .norm import BatchNorm from .conv import Conv3D from .conv import SubmConv3D diff --git a/python/paddle/sparse/layer/norm.py b/python/paddle/sparse/layer/norm.py new file mode 100644 index 0000000000000..83b738a5dc354 --- /dev/null +++ b/python/paddle/sparse/layer/norm.py @@ -0,0 +1,160 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +import warnings + + +class BatchNorm(paddle.nn.BatchNorm1D): + r""" + Applies Batch Normalization over a SparseCooTensor as described in the paper Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift . + + When use_global_stats = False, the :math:`\mu_{\beta}` + and :math:`\sigma_{\beta}^{2}` are the statistics of one mini-batch. + Calculated as follows: + + .. math:: + + \mu_{\beta} &\gets \frac{1}{m} \sum_{i=1}^{m} x_i \qquad &//\ + \ mini-batch\ mean \\ + \sigma_{\beta}^{2} &\gets \frac{1}{m} \sum_{i=1}^{m}(x_i - \ + \mu_{\beta})^2 \qquad &//\ mini-batch\ variance \\ + + When use_global_stats = True, the :math:`\mu_{\beta}` + and :math:`\sigma_{\beta}^{2}` are not the statistics of one mini-batch. + They are global or running statistics (moving_mean and moving_variance). It usually got from the + pre-trained model. Calculated as follows: + + .. math:: + moving\_mean = moving\_mean * momentum + \mu_{\beta} * (1. - momentum) \quad &// global \ mean \\ + moving\_variance = moving\_variance * momentum + \sigma_{\beta}^{2} * (1. - momentum) \quad &// global \ variance \\ + + The normalization function formula is as follows: + + .. math:: + + \hat{x_i} &\gets \frac{x_i - \mu_\beta} {\sqrt{\sigma_{\beta}^{2} + \epsilon}} \qquad &//\ normalize \\ + y_i &\gets \gamma \hat{x_i} + \beta \qquad &//\ scale\ and\ shift + + - :math:`\epsilon` : add a smaller value to the variance to prevent division by zero + - :math:`\gamma` : trainable proportional parameter + - :math:`\beta` : trainable deviation parameter + + Parameters: + num_features(int): Indicate the number of channels of the input ``Tensor``. + momentum(float, optional): The value used for the moving_mean and moving_var computation. Default: 0.9. + epsilon(float, optional): The small value added to the variance to prevent division by zero. Default: 1e-5. + weight_attr(ParamAttr|bool, optional): The parameter attribute for Parameter `scale` + of batch_norm. If it is set to None or one attribute of ParamAttr, batch_norm + will create ParamAttr as weight_attr. If it is set to Fasle, the weight is not learnable. + If the Initializer of the weight_attr is not set, the parameter is initialized with Xavier. Default: None. + bias_attr(ParamAttr|bool, optional): The parameter attribute for the bias of batch_norm. + If it is set to None or one attribute of ParamAttr, batch_norm + will create ParamAttr as bias_attr. If it is set to Fasle, the weight is not learnable. + If the Initializer of the bias_attr is not set, the bias is initialized zero. Default: None. + data_format(str, optional): Specify the input data format, may be "NC", "NCL" or "NLC". Defalut "NCL". + use_global_stats(bool|None, optional): Whether to use global mean and variance. If set to False, use the statistics of one mini-batch, if set to True, use the global statistics, if set to None, use global statistics in the test phase and use the statistics of one mini-batch in the training phase. Default: None. + name(str, optional): Name for the BatchNorm, default is None. For more information, please refer to :ref:`api_guide_Name`.. + + Shape: + - x: A SparseCooTensor with layout = 'NDHWC'. + - output: SparseCooTensor with same shape as input x. + + Returns: + None. + + + Examples: + .. code-block:: python + + import paddle + from paddle.fluid.framework import _test_eager_guard + + with _test_eager_guard(): + paddle.seed(123) + channels = 3 + x_data = paddle.randn((1, 6, 6, 6, channels)).astype('float32') + dense_x = paddle.to_tensor(x_data) + sparse_x = dense_x.to_sparse_coo(4) + batch_norm = paddle.sparse.BatchNorm(channels) + batch_norm_out = batch_norm(sparse_x) + print(batch_norm_out.shape) + # [1, 6, 6, 6, 3] + """ + + def __init__(self, + num_features, + momentum=0.9, + epsilon=1e-05, + weight_attr=None, + bias_attr=None, + data_format='NDHWC', + use_global_stats=None, + name=None): + super(BatchNorm, self).__init__( + num_features, + momentum=momentum, + epsilon=epsilon, + weight_attr=weight_attr, + bias_attr=bias_attr, + data_format=data_format, + use_global_stats=use_global_stats, + name=name) + + def _check_data_format(self, input): + if input != "NDHWC": + raise ValueError('sparse BatchNorm only support layout of "NDHWC"') + + def forward(self, input): + values = input.values() + self._check_data_format(self._data_format) + + if len(values.shape) != 2: + raise ValueError('expected 2D input.values() (got {}D)'.format( + len(values.shape))) + + if self.training: + warnings.warn( + "When training, we now always track global mean and variance.") + + batch_norm_out = paddle.nn.functional.batch_norm( + values, + self._mean, + self._variance, + weight=self.weight, + bias=self.bias, + training=self.training, + momentum=self._momentum, + epsilon=self._epsilon, + data_format='NC', + use_global_stats=self._use_global_stats) + + return paddle.sparse.sparse_coo_tensor( + input.indices(), + batch_norm_out, + shape=input.shape, + stop_gradient=input.stop_gradient)