Skip to content

Commit

Permalink
Symintify embedding_sparse_backward (pytorch#88746)
Browse files Browse the repository at this point in the history
Pull Request resolved: pytorch#88746
Approved by: https://github.com/ezyang
  • Loading branch information
anjali411 authored and pytorchmergebot committed Nov 9, 2022
1 parent b7aa22d commit 1af9b38
Showing 1 changed file with 7 additions and 7 deletions.
14 changes: 7 additions & 7 deletions aten/src/ATen/native/Embedding.cpp
Expand Up @@ -89,20 +89,20 @@ Tensor embedding_sparse_backward(
grad = grad.index(c);
}

int64_t num_features = grad_.size(-1);
auto weight_size = std::array<int64_t, 2>{{ num_weights, num_features }};
auto num_features = grad_.sym_size(-1);
auto weight_size = std::array<c10::SymInt, 2>{{ num_weights, num_features }};
auto dense_options = grad.options();

// check if all our grad come from padding_idx
if (grad.numel() == 0) {
return at::_sparse_coo_tensor_unsafe(at::empty({1, 0}, indices_.options().dtype(kLong)),
at::empty({0, num_features}, dense_options),
if (grad.sym_numel() == 0) {
return at::_sparse_coo_tensor_unsafe_symint(at::empty({1, 0}, indices_.options().dtype(kLong)),
at::empty_symint({c10::SymInt(0), num_features}, dense_options),
weight_size);
}

auto index = indices.reshape({1, -1});
auto values = grad.reshape({-1, num_features});
return at::_sparse_coo_tensor_unsafe(index.to(kLong), values, weight_size);
auto values = grad.reshape_symint({c10::SymInt(-1), num_features});
return at::_sparse_coo_tensor_unsafe_symint(index.to(kLong), values, weight_size);
}

Tensor embedding_dense_backward_cpu(
Expand Down

0 comments on commit 1af9b38

Please sign in to comment.