Fix hybrid sparse COO tensor conversion to meta tensor
ghstack-source-id: 2ce0d6b099107ca5c3d99ab4ef2d617cce0836b0
Pull Request resolved: #125120
pearu committed Apr 28, 2024
1 parent 94b328e commit f7f7a05
Showing 3 changed files with 65 additions and 3 deletions.
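
For context, the conversion this commit fixes can be sketched as follows. This is a hypothetical usage example, not part of the commit; the tensor sizes are made up, and it assumes a PyTorch build that includes this change.

    import torch

    # A "hybrid" sparse COO tensor: 2 sparse dimensions plus 1 dense dimension,
    # i.e. each specified element carries a dense slice of shape (4,).
    indices = torch.tensor([[0, 1], [2, 0]])   # 2 specified elements
    values = torch.randn(2, 4)                 # per-element dense slice of size 4
    t = torch.sparse_coo_tensor(indices, values, size=(2, 3, 4))

    # With this fix, conversion to the meta device keeps the COO layout and the
    # sparse/dense split of the original tensor.
    m = t.to(device="meta")
    assert m.is_meta and m.layout == torch.sparse_coo
    assert m.sparse_dim() == 2 and m.dense_dim() == 1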
32 changes: 31 additions & 1 deletion aten/src/ATen/native/TensorConversions.cpp
@@ -254,7 +254,37 @@ Tensor _to_copy(
 
   // TODO: Use the dispatcher for this.
   // Currently there are unenumerated extensibility issues preventing this.
-  if (at::sparse_csr::is_sparse_compressed(self)) {
+  if (self.layout() == kSparse) {
+    TORCH_CHECK(
+        memory_format == MemoryFormat::Preserve,
+        "to(options): COO only supports memory format Preserve, but got ", memory_format,
+        " instead.");
+    auto indices = self._indices();
+    const auto new_indices = at::native::to(
+        indices,
+        indices.scalar_type(),
+        c10::kStrided,
+        device,
+        pin_memory,
+        non_blocking,
+        true, // force copy since we are in _to_copy
+        memory_format);
+    const auto new_values = at::native::to(
+        self._values(),
+        dtype,
+        c10::kStrided,
+        device,
+        pin_memory,
+        non_blocking,
+        true, // force copy since we are in _to_copy
+        memory_format);
+
+    return at::_sparse_coo_tensor_unsafe(
+        new_indices,
+        new_values,
+        self.sizes(),
+        options, self.is_coalesced());
+  } else if (at::sparse_csr::is_sparse_compressed(self)) {
     TORCH_CHECK(
         memory_format == MemoryFormat::Preserve,
         "to(options): ", at::sparse_csr::layoutToString(self.layout()),
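
One detail worth noting in the new COO branch: the indices are converted with their own integer scalar type, only the values adopt the requested dtype, and the result is reassembled with _sparse_coo_tensor_unsafe so that is_coalesced is preserved. A hypothetical check of that behavior (not part of the commit; assumes a build with this change):

    import torch

    t = torch.sparse_coo_tensor(
        torch.tensor([[0, 1]]),                  # int64 indices
        torch.randn(2, 3, dtype=torch.float64),  # hybrid COO: one dense dim
        size=(2, 3),
    )
    m = t.to(device="meta", dtype=torch.float32)
    assert m._values().dtype == torch.float32    # requested dtype applies to values
    assert m._indices().dtype == torch.int64     # indices keep their own dtype
    assert m._values().is_meta and m._indices().is_meta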
5 changes: 3 additions & 2 deletions aten/src/ATen/native/sparse/SparseTensorMath.cpp
@@ -591,8 +591,9 @@ SparseTensor& add_out_sparse_cpu(const SparseTensor& t, const SparseTensor& src,
   TORCH_CHECK(is_same_density(t, src), "add: expected 'self' and 'other' to have same density, but 'self' has ", t.sparse_dim(), " sparse dimensions while 'other' has ", src.sparse_dim(), " sparse dimensions");
 
   r.resize_as_(src);
-
-  if (src._values().is_contiguous() && t._values().is_contiguous()) {
+  if (r.is_meta()) {
+    return r;
+  } else if (src._values().is_contiguous() && t._values().is_contiguous()) {
     return add_out_sparse_contiguous(r, t, src, value, commonDtype);
   } else {
     return add_out_sparse_non_contiguous(r, t, src, value, commonDtype);
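
The early return added above reflects how the meta device works in general: an operation on meta tensors only has to produce correct output metadata (shape, dtype, layout), never element data, so once the result has been resized there is nothing left for the sparse CPU kernel to do. The same principle is visible with plain strided meta tensors (an illustrative snippet, not part of the commit):

    import torch

    x = torch.empty(2, 3, device="meta")
    y = torch.empty(2, 3, device="meta")
    z = x + y                        # no values are computed, only output metadata
    assert z.is_meta and z.shape == (2, 3)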
31 changes: 31 additions & 0 deletions test/test_sparse.py
@@ -4381,6 +4381,37 @@ def test_print_meta(self, dtype, layout):
             self.maxDiff = orig_maxDiff
             raise
 
+    @all_sparse_layouts('layout', include_strided=False)
+    @parametrize("dtype", [torch.float64])
+    def test_to_meta(self, dtype, layout):
+        index_dtype = torch.int64
+        device = 'cpu'
+        for t in self.generate_simple_inputs(layout, device=device, dtype=dtype, index_dtype=index_dtype):
+            m = t.to(device="meta")
+            self.assertEqual(m.device.type, "meta")
+            self.assertEqual(m, t, exact_device=False, exact_layout=True)
+            self.assertEqual(m.sparse_dim(), t.sparse_dim())
+            self.assertEqual(m.dense_dim(), t.dense_dim())
+
+            if layout is torch.sparse_coo:
+                self.assertEqual(m._indices().device.type, "meta")
+                self.assertEqual(m._indices(), t._indices(), exact_device=False, exact_layout=True)
+                self.assertEqual(m._values().device.type, "meta")
+                self.assertEqual(m._values(), t._values(), exact_device=False, exact_layout=True)
+            else:
+                if layout in {torch.sparse_csr, torch.sparse_bsr}:
+                    m_compressed_indices, m_plain_indices = m.crow_indices(), m.col_indices()
+                    compressed_indices, plain_indices = t.crow_indices(), t.col_indices()
+                else:
+                    m_compressed_indices, m_plain_indices = m.ccol_indices(), m.row_indices()
+                    compressed_indices, plain_indices = t.ccol_indices(), t.row_indices()
+                self.assertEqual(m_compressed_indices.device.type, "meta")
+                self.assertEqual(m_compressed_indices, compressed_indices, exact_device=False, exact_layout=True)
+                self.assertEqual(m_plain_indices.device.type, "meta")
+                self.assertEqual(m_plain_indices, plain_indices, exact_device=False, exact_layout=True)
+                self.assertEqual(m.values().device.type, "meta")
+                self.assertEqual(m.values(), t.values(), exact_device=False, exact_layout=True)
+
     @all_sparse_layouts('layout', include_strided=False)
     @parametrize("dtype", [torch.float64])
     def test_fake(self, dtype, layout):
