Commit

Fix hybrid sparse COO tensor conversion to meta tensor
ghstack-source-id: 06d87bd0b180ad2429177ef484e9cf5bae432eea
Pull Request resolved: #125120
pearu committed Apr 29, 2024
1 parent 3d1dd79 commit 5509d67
Showing 3 changed files with 91 additions and 43 deletions.
32 changes: 31 additions & 1 deletion aten/src/ATen/native/TensorConversions.cpp
@@ -254,7 +254,37 @@ Tensor _to_copy(
 
   // TODO: Use the dispatcher for this.
   // Currently there are unenumerated extensibility issues preventing this.
-  if (at::sparse_csr::is_sparse_compressed(self)) {
+  if (self.layout() == kSparse) {
+    TORCH_CHECK(
+        memory_format == MemoryFormat::Preserve,
+        "to(options): COO only supports memory format Preserve, but got ", memory_format,
+        " instead.");
+    auto indices = self._indices();
+    const auto new_indices = at::native::to(
+        indices,
+        indices.scalar_type(),
+        c10::kStrided,
+        device,
+        pin_memory,
+        non_blocking,
+        true, // force copy since we are in _to_copy
+        memory_format);
+    const auto new_values = at::native::to(
+        self._values(),
+        dtype,
+        c10::kStrided,
+        device,
+        pin_memory,
+        non_blocking,
+        true, // force copy since we are in _to_copy
+        memory_format);
+
+    return at::_sparse_coo_tensor_unsafe(
+        new_indices,
+        new_values,
+        self.sizes(),
+        options, self.is_coalesced());
+  } else if (at::sparse_csr::is_sparse_compressed(self)) {
     TORCH_CHECK(
         memory_format == MemoryFormat::Preserve,
         "to(options): ", at::sparse_csr::layoutToString(self.layout()),
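For context, a minimal Python sketch (not part of the commit) of the conversion this new kSparse branch enables, using a hypothetical hybrid COO tensor with one sparse and one dense dimension; shapes and values are made up for illustration:

import torch

# Hypothetical hybrid sparse COO tensor: 1 sparse dim, 1 dense dim.
indices = torch.tensor([[0, 2]])                  # (sparse_dim, nnz) = (1, 2)
values = torch.tensor([[1.0, 2.0], [3.0, 4.0]])   # (nnz, dense size) = (2, 2)
t = torch.sparse_coo_tensor(indices, values, size=(3, 2))

# With the kSparse branch above, indices and values are converted as strided
# tensors, so hybrid COO tensors can now be moved to the meta device too.
m = t.to(device="meta")
assert m.device.type == "meta" and m.layout == torch.sparse_coo
assert m._indices().shape == (1, 2) and m._values().shape == (2, 2)

The test_to_meta test added in test_sparse.py below exercises this path for all sparse layouts.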
5 changes: 3 additions & 2 deletions aten/src/ATen/native/sparse/SparseTensorMath.cpp
@@ -591,8 +591,9 @@ SparseTensor& add_out_sparse_cpu(const SparseTensor& t, const SparseTensor& src,
   TORCH_CHECK(is_same_density(t, src), "add: expected 'self' and 'other' to have same density, but 'self' has ", t.sparse_dim(), " sparse dimensions while 'other' has ", src.sparse_dim(), " sparse dimensions");
 
   r.resize_as_(src);
-
-  if (src._values().is_contiguous() && t._values().is_contiguous()) {
+  if (r.is_meta()) {
+    return r;
+  } else if (src._values().is_contiguous() && t._values().is_contiguous()) {
     return add_out_sparse_contiguous(r, t, src, value, commonDtype);
   } else {
     return add_out_sparse_non_contiguous(r, t, src, value, commonDtype);
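For illustration only (not part of the commit), a small Python sketch of what this early return unblocks, using a made-up COO tensor; the comment about storage is an assumption about why the kernels below the new check cannot run on meta tensors:

import torch

i = torch.tensor([[0, 1], [1, 0]])
v = torch.tensor([1.0, 2.0])
t = torch.sparse_coo_tensor(i, v, size=(2, 2))

m = t.to(device="meta")
# Adding two sparse meta tensors reaches add_out_sparse_cpu with a meta
# result tensor; the contiguous/non-contiguous kernels below the new check
# assume real storage, so the function now returns the resized result directly.
r = torch.add(m, m)
assert r.is_meta and r.layout == torch.sparse_coo and r.shape == (2, 2)

For COO the returned meta result is empty (nnz == 0), which is what the updated test_add_meta below asserts.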
97 changes: 57 additions & 40 deletions test/test_sparse.py
@@ -4381,6 +4381,60 @@ def test_print_meta(self, dtype, layout):
             self.maxDiff = orig_maxDiff
             raise
 
+    def assertEqualMeta(self, x, y, expected_nnz):
+        self.assertEqual(x.layout, y.layout)
+        self.assertEqual(x.shape, y.shape)
+        self.assertEqual(x.dtype, y.dtype)
+        self.assertEqual(x.sparse_dim(), y.sparse_dim())
+        self.assertEqual(x.dense_dim(), y.dense_dim())
+
+        def assertEqualAttrs(x, y, expected_shape):
+            self.assertEqual(x.shape, expected_shape)
+            self.assertEqual(x.dtype, y.dtype)
+            self.assertEqual(x.layout, y.layout)
+            if not x.is_meta:
+                self.assertEqual(x.device, y.device)
+
+        if x.layout is torch.sparse_coo:
+            assertEqualAttrs(x._indices(), y._indices(), (*y._indices().shape[:-1], expected_nnz))
+            assertEqualAttrs(x._values(), y._values(), (expected_nnz, *y._values().shape[1:]))
+        elif x.layout in {torch.sparse_csr, torch.sparse_bsr}:
+            assertEqualAttrs(x.crow_indices(), y.crow_indices(), y.crow_indices().shape)
+            assertEqualAttrs(x.col_indices(), y.col_indices(), (*y.col_indices().shape[:-1], expected_nnz))
+            batch_dim = x.col_indices().ndim - 1
+            values_shape = (*y.values().shape[:batch_dim], expected_nnz, *y.values().shape[batch_dim + 1:])
+            self.assertEqual(x.values().layout, y.values().layout)
+            self.assertEqual(x.values().dtype, y.values().dtype)
+            self.assertEqual(x.values().shape, values_shape)
+        elif x.layout in {torch.sparse_csc, torch.sparse_bsc}:
+            assertEqualAttrs(x.ccol_indices(), y.ccol_indices(), y.ccol_indices().shape)
+            assertEqualAttrs(x.row_indices(), y.row_indices(), (*y.row_indices().shape[:-1], expected_nnz))
+            batch_dim = x.row_indices().ndim - 1
+            values_shape = (*y.values().shape[:batch_dim], expected_nnz, *y.values().shape[batch_dim + 1:])
+            self.assertEqual(x.values().layout, y.values().layout)
+            self.assertEqual(x.values().dtype, y.values().dtype)
+            self.assertEqual(x.values().shape, values_shape)
+
+    @all_sparse_layouts('layout', include_strided=False)
+    @parametrize("dtype", [torch.float64])
+    def test_to_meta(self, dtype, layout):
+        index_dtype = torch.int64
+        device = 'cpu'
+        for t in self.generate_simple_inputs(layout, device=device, dtype=dtype, index_dtype=index_dtype):
+            m = t.to(device="meta")
+            self.assertEqual(m.device.type, "meta")
+            self.assertEqualMeta(m, t, t._nnz())
+
+    @all_sparse_layouts('layout', include_strided=False)
+    @parametrize("dtype", [torch.float64])
+    def test_zeros_like_meta(self, dtype, layout):
+        index_dtype = torch.int64
+        device = 'cpu'
+        for t in self.generate_simple_inputs(layout, device=device, dtype=dtype, index_dtype=index_dtype):
+            m = torch.zeros_like(t, device="meta")
+            self.assertEqual(m.device.type, "meta")
+            self.assertEqualMeta(m, t, 0)
+
     @all_sparse_layouts('layout', include_strided=False)
     @parametrize("dtype", [torch.float64])
     def test_fake(self, dtype, layout):
@@ -4391,45 +4445,7 @@ def test_fake(self, dtype, layout):
         for t in self.generate_simple_inputs(layout, device=device, dtype=dtype, index_dtype=index_dtype):
             f = FakeTensor.from_tensor(t, fake_mode)
             self.assertIsInstance(f, FakeTensor)
-            self.assertEqual(f.layout, layout)
-            self.assertEqual(f.shape, t.shape)
-            self.assertEqual(f.device, t.device)
-            if layout is torch.sparse_coo:
-                nnz = 0
-                indices = f._indices()
-                self.assertEqual(indices.dtype, index_dtype)
-                self.assertEqual(indices.device, t.device)
-                self.assertEqual(indices.shape, (*t._indices().shape[:-1], nnz))
-                values = f._values()
-                self.assertEqual(values.dtype, dtype)
-                self.assertEqual(values.device, t.device)
-                self.assertEqual(values.shape, (nnz, *t._values().shape[1:]))
-            else:
-                nnz = 0
-                if layout in {torch.sparse_csr, torch.sparse_bsr}:
-                    f_compressed_indices, f_plain_indices = f.crow_indices(), f.col_indices()
-                    compressed_indices, plain_indices = t.crow_indices(), t.col_indices()
-                else:
-                    f_compressed_indices, f_plain_indices = f.ccol_indices(), f.row_indices()
-                    compressed_indices, plain_indices = t.ccol_indices(), t.row_indices()
-                f_values = f.values()
-                values = t.values()
-                batch_dims = len(compressed_indices.shape) - 1
-                self.assertEqual(f_compressed_indices.layout, compressed_indices.layout)
-                self.assertEqual(f_compressed_indices.shape, compressed_indices.shape)
-                self.assertEqual(f_compressed_indices.dtype, compressed_indices.dtype)
-                self.assertEqual(f_compressed_indices.device, compressed_indices.device)
-
-                self.assertEqual(f_plain_indices.layout, plain_indices.layout)
-                self.assertEqual(f_plain_indices.shape, (*plain_indices.shape[:-1], nnz))
-                self.assertEqual(f_plain_indices.dtype, plain_indices.dtype)
-                self.assertEqual(f_plain_indices.device, plain_indices.device)
-
-                batch_dim = plain_indices.ndim - 1
-                self.assertEqual(f_values.layout, values.layout)
-                self.assertEqual(f_values.shape, (*values.shape[:batch_dim], nnz, *values.shape[batch_dim + 1:]))
-                self.assertEqual(f_values.dtype, values.dtype)
-                self.assertEqual(f_values.device, values.device)
+            self.assertEqualMeta(f, t, 0)
 
     @all_sparse_layouts('layout', include_strided=False)
     @parametrize("dtype", [torch.float64])
@@ -4464,9 +4480,10 @@ def test_add_meta(self, dtype, layout):
         device = 'cpu'
         index_dtype = torch.int64
        for t in self.generate_simple_inputs(layout, device=device, dtype=dtype, index_dtype=index_dtype):
+            expected = torch.add(t, t).to(device='meta')
             m = t.to(device='meta')
             r = torch.add(m, m)
-            self.assertEqual(r, m)
+            self.assertEqualMeta(r, expected, 0 if layout is torch.sparse_coo else expected._nnz())
 
 
 class _SparseDataset(torch.utils.data.Dataset):
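Finally, a usage-level Python sketch of what the new tests assert, with a small made-up COO tensor rather than generate_simple_inputs:

import torch

t = torch.eye(3).to_sparse()                # small COO tensor, nnz == 3

m = t.to(device="meta")
z = torch.zeros_like(t, device="meta")

assert m._nnz() == t._nnz()                 # to(device="meta") keeps nnz
assert z._nnz() == 0                        # zeros_like starts out empty
assert m.layout == z.layout == torch.sparse_coo
assert m.shape == z.shape == t.shape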
