Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add API: Sparse Convolution3D #41434

Merged
merged 25 commits into from Apr 15, 2022
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
6 changes: 3 additions & 3 deletions paddle/phi/kernels/sparse/cpu/convolution_grad_kernel.cc
Expand Up @@ -140,16 +140,16 @@ void Conv3dGradCPUKernel(const CPUContext& dev_ctx,
T* tmp_in_ptr = in_features_ptr + offsets[i] * in_channels;
T* tmp_out_grad_ptr = out_grad_features_ptr + offsets[i] * out_channels;
const T* tmp_kernel_ptr = kernel_ptr + i * in_channels * out_channels;
T* tmp_d_x_ptr = d_x_features_ptr + offsets[i] * out_channels;
T* tmp_d_x_ptr = d_x_features_ptr + offsets[i] * in_channels;
T* tmp_d_kernel_ptr = d_kernel_ptr + i * in_channels * out_channels;

// call gemm: d_kernel = transpose(x) * out_grad
// (in_channels, n) * (n, out_channels)
blas.GEMM(CblasTrans,
CblasNoTrans,
M,
N,
K,
N,
M,
static_cast<T>(1),
tmp_in_ptr,
tmp_out_grad_ptr,
Expand Down
13 changes: 8 additions & 5 deletions paddle/phi/kernels/sparse/cpu/convolution_kernel.cc
Expand Up @@ -51,16 +51,19 @@ void Conv3dCPUKernel(const CPUContext& dev_ctx,
kernel_sizes[i] = kernel_dims[i];
}

phi::funcs::sparse::GetOutShape(
x_dims, kernel_sizes, paddings, dilations, strides, &out_dims);
const int in_channels = kernel_dims[3];
const int out_channels = kernel_dims[4];

std::vector<int> subm_paddings(paddings), subm_strides(strides);
if (subm) {
// the out shape of subm_conv is same as input shape
// reset the padding=kernel_size/2 and strides=1
phi::funcs::sparse::ResetSubmKernelSizeAndStrides(
kernel.dims(), &subm_paddings, &subm_strides);
}

phi::funcs::sparse::GetOutShape(
x_dims, kernel_sizes, subm_paddings, dilations, subm_strides, &out_dims);
const int in_channels = kernel_dims[3];
const int out_channels = kernel_dims[4];

// Second algorithm:
// https://pdfs.semanticscholar.org/5125/a16039cabc6320c908a4764f32596e018ad3.pdf
// 1. product rulebook
Expand Down
6 changes: 3 additions & 3 deletions paddle/phi/kernels/sparse/gpu/convolution_grad_kernel.cu
Expand Up @@ -172,16 +172,16 @@ void Conv3dGradGPUKernel(const GPUContext& dev_ctx,
T* tmp_in_ptr = in_features_ptr + offsets[i] * in_channels;
T* tmp_out_grad_ptr = out_grad_features_ptr + offsets[i] * out_channels;
const T* tmp_kernel_ptr = kernel_ptr + i * in_channels * out_channels;
T* tmp_d_x_ptr = d_x_features_ptr + offsets[i] * out_channels;
T* tmp_d_x_ptr = d_x_features_ptr + offsets[i] * in_channels;
T* tmp_d_kernel_ptr = d_kernel_ptr + i * in_channels * out_channels;

// call gemm: d_kernel = transpose(x) * out_grad
// (in_channels, n) * (n, out_channels)
blas.GEMM(CblasTrans,
CblasNoTrans,
M,
N,
K,
N,
M,
static_cast<T>(1),
tmp_in_ptr,
tmp_out_grad_ptr,
Expand Down
16 changes: 10 additions & 6 deletions paddle/phi/kernels/sparse/gpu/convolution_kernel.cu
Expand Up @@ -46,8 +46,17 @@ void Conv3dGPUKernel(const GPUContext& dev_ctx,
for (int i = 0; i < kernel_dims.size(); i++) {
kernel_sizes[i] = kernel_dims[i];
}

std::vector<int> subm_paddings(paddings), subm_strides(strides);
if (subm) {
// the out shape of subm_conv is same as input shape
// reset the padding=kernel_size/2 and strides=1
phi::funcs::sparse::ResetSubmKernelSizeAndStrides(
kernel.dims(), &subm_paddings, &subm_strides);
}

phi::funcs::sparse::GetOutShape(
x_dims, kernel_sizes, paddings, dilations, strides, &out_dims);
x_dims, kernel_sizes, subm_paddings, dilations, subm_strides, &out_dims);
const int in_channels = kernel_dims[3];
const int out_channels = kernel_dims[4];
std::vector<int> offsets(kernel_size + 1), h_counter(kernel_size);
Expand All @@ -65,11 +74,6 @@ void Conv3dGPUKernel(const GPUContext& dev_ctx,
DenseTensor out_index = phi::Empty(dev_ctx, std::move(index_meta));
DenseTensor unique_value = phi::Empty(dev_ctx, std::move(index_meta));

std::vector<int> subm_paddings(paddings), subm_strides(strides);
if (subm) {
phi::funcs::sparse::ResetSubmKernelSizeAndStrides(
kernel.dims(), &subm_paddings, &subm_strides);
}
int n = ProductRuleBook<T, GPUContext, IntT>(dev_ctx,
x,
kernel_sizes,
Expand Down
66 changes: 66 additions & 0 deletions python/paddle/fluid/dygraph/varbase_patch_methods.py
Expand Up @@ -871,6 +871,28 @@ def pin_memory(self):

@framework.dygraph_only
def values(self):
"""
**Notes**:
**This API is ONLY available in Dygraph mode**

Get the values of current SparseTensor(COO or CSR).

Returns:
Tensor: A DenseTensor

Examples:
.. code-block:: python

import paddle
from paddle.fluid.framework import _test_eager_guard
zkh2016 marked this conversation as resolved.
Show resolved Hide resolved

with _test_eager_guard():
indices = [[0, 0, 1, 2, 2], [1, 3, 2, 0, 1]]
values = [1, 2, 3, 4, 5]
dense_shape = [3, 4]
sparse_x = paddle.sparse.sparse_coo_tensor(paddle.to_tensor(indices, dtype='int32'), paddle.to_tensor(values, dtype='float32'), dense_shape=dense_shape)
print(sparse_x.values())
zkh2016 marked this conversation as resolved.
Show resolved Hide resolved
"""
if self.is_sparse_coo():
return _C_ops.final_state_sparse_coo_values(self)
elif self.is_sparse_csr():
Expand All @@ -881,6 +903,29 @@ def values(self):

@framework.dygraph_only
def to_dense(self):
"""
**Notes**:
**This API is ONLY available in Dygraph mode**

Convert the current SparseTensor(COO or CSR) to DenseTensor.

Returns:
Tensor: A DenseTensor

Examples:
.. code-block:: python

import paddle
from paddle.fluid.framework import _test_eager_guard
zkh2016 marked this conversation as resolved.
Show resolved Hide resolved

with _test_eager_guard():
indices = [[0, 0, 1, 2, 2], [1, 3, 2, 0, 1]]
values = [1, 2, 3, 4, 5]
dense_shape = [3, 4]
sparse_x = paddle.sparse.sparse_coo_tensor(paddle.to_tensor(indices, dtype='int32'), paddle.to_tensor(values, dtype='float32'), dense_shape=dense_shape)
dense_x = sparse_x.to_dense()

"""
if self.is_sparse_coo():
return _C_ops.final_state_sparse_coo_to_dense(self)
elif self.is_sparse_csr():
Expand All @@ -890,6 +935,27 @@ def to_dense(self):

@framework.dygraph_only
zkh2016 marked this conversation as resolved.
Show resolved Hide resolved
def to_sparse_coo(self, sparse_dim):
"""
**Notes**:
**This API is ONLY available in Dygraph mode**

Convert the current DenseTensor to SparseTensor in COO format.

Returns:
Tensor: A SparseCooTensor

Examples:
.. code-block:: python

import paddle
from paddle.fluid.framework import _test_eager_guard
zkh2016 marked this conversation as resolved.
Show resolved Hide resolved

with _test_eager_guard():
dense_x = [[0, 1, 0, 2], [0, 0, 3, 4]]
dense_x = paddle.to_tensor(dense_x, dtype='float32')
sparse_x = dense_x.to_sparse_coo(sparse_dim=2)
"""

if self.is_sparse_csr():
return _C_ops.final_state_sparse_to_sparse_coo(self, sparse_dim)
elif self.is_sparse_coo():
Expand Down
76 changes: 69 additions & 7 deletions python/paddle/fluid/tests/unittests/test_sparse_conv_op.py
Expand Up @@ -40,14 +40,76 @@ def test_conv3d(self):
correct_out_values = [[4], [10]]
sparse_input = core.eager.sparse_coo_tensor(indices, values,
dense_shape, False)
out = _C_ops.final_state_sparse_conv3d(sparse_input, dense_kernel,
paddings, dilations, strides,
1, False)
out = paddle.sparse.functional.conv3d(
sparse_input,
dense_kernel,
bias=None,
stride=strides,
padding=paddings,
dilation=dilations,
groups=1,
data_format="NDHWC")
out.backward(out)
#At present, only backward can be verified to work normally
#TODO(zhangkaihuo): compare the result with dense conv
print(sparse_input.grad.values())
assert np.array_equal(correct_out_values, out.values().numpy())

def test_subm_conv3d(self):
with _test_eager_guard():
indices = [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 1, 2], [1, 3, 2, 3]]
values = [[1], [2], [3], [4]]
indices = paddle.to_tensor(indices, dtype='int32')
values = paddle.to_tensor(values, dtype='float32')
dense_shape = [1, 1, 3, 4, 1]
sparse_x = paddle.sparse.sparse_coo_tensor(
indices, values, dense_shape, stop_gradient=True)
weight = paddle.randn((1, 3, 3, 1, 1), dtype='float32')
y = paddle.sparse.functional.subm_conv3d(sparse_x, weight)
assert np.array_equal(sparse_x.indices().numpy(),
y.indices().numpy())

def test_Conv3D(self):
with _test_eager_guard():
#(4, non_zero_num), 4-D:(N, D, H, W)
indices = [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 1, 2], [1, 3, 2, 3]]
#(non_zero_num, C)
values = [[1], [2], [3], [4]]
indices = paddle.to_tensor(indices, dtype='int32')
values = paddle.to_tensor(values, dtype='float32')
dense_shape = [1, 1, 3, 4, 1]
correct_out_values = [[4], [10]]
sparse_input = paddle.sparse.sparse_coo_tensor(indices, values,
dense_shape, False)

sparse_conv3d = paddle.sparse.Conv3D(
1, 1, (1, 3, 3), data_format='NDHWC')
sparse_out = sparse_conv3d(sparse_input)
#test errors
with self.assertRaises(ValueError):
#Currently, only support data_format='NDHWC'
conv3d = paddle.sparse.SubmConv3D(
1, 1, (1, 3, 3), data_format='NCDHW')

def test_SubmConv3D(self):
with _test_eager_guard():
indices = [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 1, 2], [1, 3, 2, 3]]
values = [[1], [2], [3], [4]]
indices = paddle.to_tensor(indices, dtype='int32')
values = paddle.to_tensor(values, dtype='float32')
dense_shape = [1, 1, 3, 4, 1]
correct_out_values = [[4], [10]]
sparse_input = paddle.sparse.sparse_coo_tensor(indices, values,
dense_shape, False)

subm_conv3d = paddle.sparse.SubmConv3D(
1, 1, (1, 3, 3), data_format='NDHWC')
# test extra_repr
print(subm_conv3d.extra_repr())

sparse_out = subm_conv3d(sparse_input)
# the output shape of subm_conv is same as input shape
assert np.array_equal(indices, sparse_out.indices().numpy())

#TODO: Add more test case
#test errors
with self.assertRaises(ValueError):
#Currently, only support data_format='NDHWC'
conv3d = paddle.sparse.SubmConv3D(
1, 1, (1, 3, 3), data_format='NCDHW')
6 changes: 5 additions & 1 deletion python/paddle/sparse/__init__.py
Expand Up @@ -15,5 +15,9 @@
from .creation import sparse_coo_tensor
from .creation import sparse_csr_tensor
from .layer.activation import ReLU
from .layer.conv import Conv3D
from .layer.conv import SubmConv3D

__all__ = ['sparse_coo_tensor', 'sparse_csr_tensor', 'ReLU']
__all__ = [
'sparse_coo_tensor', 'sparse_csr_tensor', 'ReLU', 'Conv3D', 'SubmConv3D'
]
4 changes: 3 additions & 1 deletion python/paddle/sparse/functional/__init__.py
Expand Up @@ -13,5 +13,7 @@
# limitations under the License.

from .activation import relu # noqa: F401
from .conv import conv3d # noqa: F401
from .conv import subm_conv3d # noqa: F401

__all__ = ['relu']
__all__ = ['relu', 'conv3d', 'subm_conv3d']