Add Sparse BatchNorm and fix two bugs #42013

Merged (27 commits, Apr 22, 2022)
2 changes: 1 addition & 1 deletion paddle/phi/kernels/sparse/cpu/coalesced_kernel.cc
@@ -44,7 +44,7 @@ void CoalescedCPUKernel(const CPUContext& dev_ctx,

const T* x_values_ptr = x_values.data<T>();
const int64_t stride =
-        x.dims().size() == sparse_dim ? 1 : x.dims().size() - sparse_dim;
+        x.dims().size() == sparse_dim ? 1 : x.non_zero_elements().dims()[1];

std::map<IntT, std::vector<int64_t>> indices_to_index;
for (uint64_t i = 0; i < x_indexs.size(); i++) {
2 changes: 1 addition & 1 deletion paddle/phi/kernels/sparse/cpu/sparse_mask_kernel.cc
@@ -125,7 +125,7 @@ void SparseMaskHelperCPUKernel(const CPUContext& dev_ctx,
T* out_ptr = out->data<T>();
memset(out_ptr, static_cast<T>(0), out->numel() * sizeof(T));
const int64_t stride =
-      x.dims().size() == sparse_dim ? 1 : x.dims().size() - sparse_dim;
+      x.dims().size() == sparse_dim ? 1 : x.non_zero_elements().dims()[1];
const T* in_ptr = x.non_zero_elements().data<T>();
// TODO(zhangkaihuo): multithreading can be used for acceleration
for (uint64_t i = 0; i < mask_indexs.size(); i++) {
2 changes: 1 addition & 1 deletion paddle/phi/kernels/sparse/gpu/coalesced_kernel.cu
@@ -76,7 +76,7 @@ void CoalescedGPUKernel(const GPUContext& dev_ctx,
// 2. get the address of each non-zero values
const T* x_values_ptr = x_values.data<T>();
const int64_t stride =
-      x.dims().size() == sparse_dim ? 1 : x.dims().size() - sparse_dim;
+      x.dims().size() == sparse_dim ? 1 : x.non_zero_elements().dims()[1];
DenseTensor values_indexs = phi::Empty(
dev_ctx, DenseTensorMeta(DataType::INT32, {nnz}, DataLayout::NCHW));
int* values_indexs_ptr = values_indexs.data<int>();
2 changes: 1 addition & 1 deletion paddle/phi/kernels/sparse/gpu/sparse_mask_kernel.cu
@@ -231,7 +231,7 @@ void SparseMaskHelperGPUKernel(const GPUContext& dev_ctx,
T* out_ptr = out->data<T>();

const int64_t stride =
-      x.dims().size() == sparse_dim ? 1 : x.dims().size() - sparse_dim;
+      x.dims().size() == sparse_dim ? 1 : x.non_zero_elements().dims()[1];

SparseMaskCopyKernel<<<config.block_per_grid,
config.thread_per_block,
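All four kernel edits above are the same one-line bug fix. `stride` is the number of elements of the values tensor consumed per non-zero entry: 1 when the tensor is fully sparse (1-D values), and the row width of `non_zero_elements()` otherwise. The old expression `x.dims().size() - sparse_dim` counted trailing *dimensions*, not elements. A minimal Python sketch of the corrected rule (the names are illustrative, not Paddle API):

```python
def values_stride(dense_ndim, sparse_dim, values_shape):
    # Elements of the values tensor per non-zero entry: 1 for scalar
    # values (shape [nnz]), else the row width C of values [nnz, C].
    if dense_ndim == sparse_dim:
        return 1
    return values_shape[1]  # old code used dense_ndim - sparse_dim

# e.g. a dense shape [3, 4, 2] with sparse_dim=2 stores values of shape [nnz, 2]:
assert values_stride(3, 2, (5, 2)) == 2  # the old expression gave 3 - 2 == 1
```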
6 changes: 4 additions & 2 deletions python/paddle/fluid/tests/unittests/test_sparse_conv_op.py
@@ -31,19 +31,21 @@ def test_conv3d(self):
paddings = [0, 0, 0]
strides = [1, 1, 1]
dilations = [1, 1, 1]
+            bias = [1]

indices = [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 1, 2], [1, 3, 2, 3]]
values = [1, 2, 3, 4]
indices = paddle.to_tensor(indices, dtype='int32')
values = paddle.to_tensor(values, dtype='float32')
dense_shape = [1, 1, 3, 4, 1]
-            correct_out_values = [[4], [10]]
+            correct_out_values = [[5], [11]]
sparse_input = core.eager.sparse_coo_tensor(indices, values,
dense_shape, False)
out = paddle.sparse.functional.conv3d(
sparse_input,
dense_kernel,
-                bias=None,
+                bias=paddle.to_tensor(
+                    bias, dtype='float32'),
stride=strides,
padding=paddings,
dilation=dilations,
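The expected outputs change because the test now passes a bias of 1, which is added to every output channel: 4 + 1 = 5 and 10 + 1 = 11.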
87 changes: 87 additions & 0 deletions python/paddle/fluid/tests/unittests/test_sparse_norm_op.py
@@ -0,0 +1,87 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function
import unittest
import numpy as np
import paddle
from paddle.fluid.framework import _test_eager_guard
import copy


class TestSparseBatchNorm(unittest.TestCase):
def test(self):
with _test_eager_guard():
paddle.seed(0)
channels = 4
shape = [2, 3, 6, 6, channels]
            # dense_x contains no zeros
dense_x = paddle.randn(shape)
dense_x.stop_gradient = False

batch_norm = paddle.nn.BatchNorm3D(channels, data_format="NDHWC")
dense_y = batch_norm(dense_x)
dense_y.backward(dense_y)

sparse_dim = 4
dense_x2 = copy.deepcopy(dense_x)
dense_x2.stop_gradient = False
sparse_x = dense_x2.to_sparse_coo(sparse_dim)
sparse_batch_norm = paddle.sparse.BatchNorm(channels)
# set same params
sparse_batch_norm._mean.set_value(batch_norm._mean)
sparse_batch_norm._variance.set_value(batch_norm._variance)
sparse_batch_norm.weight.set_value(batch_norm.weight)

sparse_y = sparse_batch_norm(sparse_x)
# compare the result with dense batch_norm
assert np.allclose(
dense_y.flatten().numpy(),
sparse_y.values().flatten().numpy(),
atol=1e-5,
rtol=1e-5)

# test backward
sparse_y.backward(sparse_y)
assert np.allclose(
dense_x.grad.flatten().numpy(),
sparse_x.grad.values().flatten().numpy(),
atol=1e-5,
rtol=1e-5)

def test_error_layout(self):
with _test_eager_guard():
with self.assertRaises(ValueError):
shape = [2, 3, 6, 6, 3]
x = paddle.randn(shape)
sparse_x = x.to_sparse_coo(4)
sparse_batch_norm = paddle.sparse.BatchNorm(
3, data_format='NCDHW')
sparse_batch_norm(sparse_x)

def test2(self):
with _test_eager_guard():
paddle.seed(123)
channels = 3
x_data = paddle.randn((1, 6, 6, 6, channels)).astype('float32')
dense_x = paddle.to_tensor(x_data)
sparse_x = dense_x.to_sparse_coo(4)
batch_norm = paddle.sparse.BatchNorm(channels)
batch_norm_out = batch_norm(sparse_x)
print(batch_norm_out.shape)
# [1, 6, 6, 6, 3]


if __name__ == "__main__":
unittest.main()
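The first test asserts that sparse BatchNorm matches dense `BatchNorm3D` on an input with no zeros, so the non-zero values cover every element and the per-channel statistics agree. One way to picture the sparse variant — a hedged sketch, not necessarily Paddle's actual implementation — is a 1-D batch norm over the `[nnz, channels]` values array, rewrapped with the original indices:

```python
import paddle
from paddle.fluid.framework import _test_eager_guard

with _test_eager_guard():
    channels = 4
    dense_x = paddle.randn([2, 3, 6, 6, channels])  # NDHWC, no zeros
    sparse_x = dense_x.to_sparse_coo(4)             # values(): [nnz, channels]

    # Normalizing the values array per channel reproduces BatchNorm3D here,
    # since every dense element appears among the non-zero values.
    bn1d = paddle.nn.BatchNorm1D(channels)
    out = paddle.sparse.sparse_coo_tensor(
        sparse_x.indices(), bn1d(sparse_x.values()), shape=sparse_x.shape)
```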
39 changes: 39 additions & 0 deletions python/paddle/fluid/tests/unittests/test_sparse_utils_op.py
@@ -208,6 +208,20 @@ def test_coo_values_grad(self):
# test coo_values_grad
values_tensor.backward(paddle.to_tensor(out_grad))
assert np.array_equal(out_grad, sparse_x.grad.values().numpy())
+            indices = [[0, 0, 1, 2, 2], [1, 3, 2, 0, 1]]
+            values = [[1.0, 1.0], [2.0, 2.0], [3.0, 3.0], [4.0, 4.0],
+                      [5.0, 5.0]]
+            sparse_x = paddle.sparse.sparse_coo_tensor(
+                paddle.to_tensor(indices),
+                paddle.to_tensor(values),
+                shape=[3, 4, 2],
+                stop_gradient=False)
+            values_tensor = sparse_x.values()
+            out_grad = [[2.0, 2.0], [3.0, 3.0], [5.0, 5.0], [8.0, 8.0],
+                        [9.0, 9.0]]
+            # test coo_values_grad
+            values_tensor.backward(paddle.to_tensor(out_grad))
+            assert np.array_equal(out_grad, sparse_x.grad.values().numpy())

def test_sparse_coo_tensor_grad(self):
with _test_eager_guard():
@@ -233,6 +247,21 @@ def test_sparse_coo_tensor_grad(self):
assert np.array_equal(correct_values_grad,
values.grad.numpy())

+            # test the case where each non-zero value is a vector
+            values = [[1, 1], [2, 2]]
+            values = paddle.to_tensor(
+                values, dtype='float32', stop_gradient=False)
+            sparse_x = paddle.sparse.sparse_coo_tensor(
+                indices, values, shape=[2, 2, 2], stop_gradient=False)
+            grad_values = [[2, 2], [3, 3]]
+            grad_values = paddle.to_tensor(grad_values, dtype='float32')
+            sparse_out_grad = paddle.sparse.sparse_coo_tensor(
+                grad_indices, grad_values, shape=[2, 2, 2])
+            sparse_x.backward(sparse_out_grad)
+            correct_values_grad = [[0, 0], [3, 3]]
+            assert np.array_equal(correct_values_grad,
+                                  values.grad.numpy())

def test_sparse_coo_tensor_sorted(self):
with _test_eager_guard():
for device in devices:
@@ -252,6 +281,16 @@ def test_sparse_coo_tensor_sorted(self):
assert np.array_equal(values_sorted,
sparse_x.values().numpy())

+                # test the case where each non-zero value is a vector
+                values = [[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]]
+                values = paddle.to_tensor(values, dtype='float32')
+                sparse_x = paddle.sparse.sparse_coo_tensor(indices, values)
+                values_sorted = [[5.0, 5.0], [1.0, 1.0]]
+                assert np.array_equal(indices_sorted,
+                                      sparse_x.indices().numpy())
+                assert np.array_equal(values_sorted,
+                                      sparse_x.values().numpy())


class TestCooError(unittest.TestCase):
def test_small_shape(self):
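These new assertions exercise the coalescing path that the first two kernel fixes touch: constructing a COO tensor from duplicate or unsorted indices sums values that share an index and sorts the result, and with vector values each duplicate contributes a whole row. A standalone numpy sketch of that semantics (assumed behavior, not Paddle's kernel):

```python
import numpy as np

def coalesce(indices, values):
    """Sum values with identical indices; return sorted unique indices.

    indices: [sparse_dim, nnz]; values: [nnz, ...] (rows may be vectors).
    """
    indices = np.asarray(indices)
    values = np.asarray(values, dtype=np.float64)
    merged = {}
    for key, row in zip((tuple(col) for col in indices.T), values):
        merged[key] = merged.get(key, 0) + row   # sum duplicate entries
    sorted_keys = sorted(merged)                 # lexicographic index order
    return (np.array(sorted_keys).T,
            np.stack([merged[k] for k in sorted_keys]))

# duplicate index (0, 1) appears twice: rows [2, 2] and [3, 3] sum to [5, 5]
idx, vals = coalesce([[0, 0, 0], [1, 1, 2]],
                     [[2., 2.], [3., 3.], [1., 1.]])
assert vals.tolist() == [[5., 5.], [1., 1.]]
```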
5 changes: 4 additions & 1 deletion python/paddle/sparse/__init__.py
@@ -15,9 +15,12 @@
from .creation import sparse_coo_tensor
from .creation import sparse_csr_tensor
from .layer.activation import ReLU
+from .layer.norm import BatchNorm

from .layer.conv import Conv3D
from .layer.conv import SubmConv3D

__all__ = [
-    'sparse_coo_tensor', 'sparse_csr_tensor', 'ReLU', 'Conv3D', 'SubmConv3D'
+    'sparse_coo_tensor', 'sparse_csr_tensor', 'ReLU', 'Conv3D', 'SubmConv3D',
+    'BatchNorm'
]
14 changes: 10 additions & 4 deletions python/paddle/sparse/creation.py
@@ -20,6 +20,8 @@
from ..tensor import max
from ..fluid.data_feeder import check_variable_and_dtype, check_type, check_dtype, convert_dtype

+import numpy as np
+
__all__ = [
'sparse_coo_tensor',
'sparse_csr_tensor',
@@ -33,11 +35,14 @@ def _handle_dtype(data, dtype):
return data


-def _infer_dense_shape(indices):
+def _infer_dense_shape(indices, values):
assert len(indices.shape) == 2
lens = max(indices, axis=1)
lens = lens + 1
-    return list(lens.numpy())
+    lens = lens.numpy()
+    if len(values.shape) > 1:
+        lens = np.append(lens, values.shape[1:])
+    return list(lens)


def _get_place(place):
Expand Down Expand Up @@ -106,7 +111,7 @@ def sparse_coo_tensor(indices,
with _test_eager_guard():
indices = [[0, 1, 2], [1, 2, 0]]
values = [1.0, 2.0, 3.0]
-            dense_shape = [2, 3]
+            dense_shape = [3, 3]
coo = paddle.sparse.sparse_coo_tensor(indices, values, dense_shape)
# print(coo)
# Tensor(shape=[2, 3], dtype=paddle.float32, place=Place(gpu:0), stop_gradient=True,
@@ -145,7 +150,8 @@ def sparse_coo_tensor(indices,
values = _handle_dtype(values, dtype)
values.stop_gradient = stop_gradient

-    min_shape = _infer_dense_shape(indices)
+    min_shape = _infer_dense_shape(indices, values)
+
if shape is None:
shape = min_shape
else:
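The `_infer_dense_shape` change makes shape inference account for non-scalar values: the inferred dense shape is the per-dimension maximum index plus one, extended by the trailing dims of `values` when each entry is a vector. A standalone numpy sketch of the same rule:

```python
import numpy as np

def infer_dense_shape(indices, values):
    """Smallest dense shape compatible with COO indices and values.

    indices: [sparse_dim, nnz] integer array; values: [nnz, ...] array.
    """
    indices = np.asarray(indices)
    values = np.asarray(values)
    assert indices.ndim == 2
    lens = indices.max(axis=1) + 1          # per-dimension extent
    if values.ndim > 1:                     # vector (or higher-rank) entries
        lens = np.append(lens, values.shape[1:])
    return list(lens)

# indices [[0,0,1,2,2],[1,3,2,0,1]] with 2-element values -> [3, 4, 2]
assert infer_dense_shape([[0, 0, 1, 2, 2], [1, 3, 2, 0, 1]],
                         [[1., 1.]] * 5) == [3, 4, 2]
```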
17 changes: 14 additions & 3 deletions python/paddle/sparse/functional/conv.py
@@ -16,6 +16,8 @@

from paddle import _C_ops, in_dynamic_mode
from ...fluid.layers.utils import convert_to_list
+from ...fluid.layers.nn import elementwise_add
+from .. import sparse_coo_tensor
from paddle.nn.functional.conv import _update_padding_nd


@@ -30,7 +32,6 @@ def _conv3d(x,
data_format="NDHWC",
name=None):
assert in_dynamic_mode(), "Currently, only support dynamic mode"
-    assert bias == None, "Currently, sparse_conv3d does not support bias"
assert groups == 1, "Currently, only support groups=1"

dims = 3
@@ -61,8 +62,18 @@
dilation = convert_to_list(dilation, dims, 'dilation')
op_type = "conv3d"

-    return _C_ops.final_state_sparse_conv3d(x, weight, padding, dilation,
-                                            stride, groups, subm)
+    pre_bias = _C_ops.final_state_sparse_conv3d(x, weight, padding, dilation,
+                                                stride, groups, subm)
+    if bias is not None:
+        values = pre_bias.values()
+        add_bias = elementwise_add(values, bias, axis=1)
+        return sparse_coo_tensor(
+            pre_bias.indices(),
+            add_bias,
+            shape=pre_bias.shape,
+            stop_gradient=pre_bias.stop_gradient)
+    else:
+        return pre_bias


def conv3d(x,
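With the `bias == None` assertion removed, `_conv3d` now adds the bias to the convolution's non-zero values via `elementwise_add` and rewraps them in a new COO tensor carrying the original indices. A hedged usage sketch matching the updated conv test (the kernel contents and its assumed DHWCK layout are illustrative; the PR does not show `dense_kernel`):

```python
import paddle
from paddle.fluid.framework import _test_eager_guard

with _test_eager_guard():
    indices = paddle.to_tensor(
        [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 1, 2], [1, 3, 2, 3]], dtype='int32')
    values = paddle.to_tensor([1., 2., 3., 4.], dtype='float32')
    sparse_x = paddle.sparse.sparse_coo_tensor(indices, values, [1, 1, 3, 4, 1])

    kernel = paddle.randn([1, 3, 3, 1, 1])  # kd, kh, kw, in, out (assumed layout)
    out = paddle.sparse.functional.conv3d(
        sparse_x, kernel,
        bias=paddle.to_tensor([1.], dtype='float32'),  # one bias per out channel
        stride=[1, 1, 1], padding=[0, 0, 0], dilation=[1, 1, 1])
    # out.values() equals the bias-free conv values shifted by +1 per channel
```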
1 change: 1 addition & 0 deletions python/paddle/sparse/layer/__init__.py
@@ -13,6 +13,7 @@
# limitations under the License.

from .activation import ReLU
+from .norm import BatchNorm
from .conv import Conv3D
from .conv import SubmConv3D
