From 1af9b38a907cdd8a21f4e0a363af3f136fa4062a Mon Sep 17 00:00:00 2001
From: anjali411
Date: Wed, 9 Nov 2022 14:48:20 +0000
Subject: [PATCH] Symintify embedding_sparse_backward (#88746)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/88746
Approved by: https://github.com/ezyang
---
 aten/src/ATen/native/Embedding.cpp | 14 +++++++-------
 1 file changed, 7 insertions(+), 7 deletions(-)

diff --git a/aten/src/ATen/native/Embedding.cpp b/aten/src/ATen/native/Embedding.cpp
index f23594022991..5972ce0d2404 100644
--- a/aten/src/ATen/native/Embedding.cpp
+++ b/aten/src/ATen/native/Embedding.cpp
@@ -89,20 +89,20 @@ Tensor embedding_sparse_backward(
     grad = grad.index(c);
   }
 
-  int64_t num_features = grad_.size(-1);
-  auto weight_size = std::array<int64_t, 2>{{ num_weights, num_features }};
+  auto num_features = grad_.sym_size(-1);
+  auto weight_size = std::array<c10::SymInt, 2>{{ num_weights, num_features }};
   auto dense_options = grad.options();
 
   // check if all our grad come from padding_idx
-  if (grad.numel() == 0) {
-    return at::_sparse_coo_tensor_unsafe(at::empty({1, 0}, indices_.options().dtype(kLong)),
-                                         at::empty({0, num_features}, dense_options),
+  if (grad.sym_numel() == 0) {
+    return at::_sparse_coo_tensor_unsafe_symint(at::empty({1, 0}, indices_.options().dtype(kLong)),
+                                         at::empty_symint({c10::SymInt(0), num_features}, dense_options),
                                          weight_size);
   }
 
   auto index = indices.reshape({1, -1});
-  auto values = grad.reshape({-1, num_features});
-  return at::_sparse_coo_tensor_unsafe(index.to(kLong), values, weight_size);
+  auto values = grad.reshape_symint({c10::SymInt(-1), num_features});
+  return at::_sparse_coo_tensor_unsafe_symint(index.to(kLong), values, weight_size);
 }
 
 Tensor embedding_dense_backward_cpu(
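
Note (not part of the patch): the hunk above applies the standard SymInt-ification pattern, replacing concrete int64_t sizes with c10::SymInt so that shapes stay symbolic under symbolic shape tracing instead of being forced to concrete integers. A minimal sketch of the pattern follows, assuming an ATen build; make_padding_grad is a hypothetical helper invented for illustration, not a function from this PR:

    #include <ATen/ATen.h>
    #include <c10/core/SymInt.h>

    // Hypothetical helper illustrating the int64_t -> c10::SymInt migration.
    at::Tensor make_padding_grad(const at::Tensor& grad) {
      // sym_size(-1) returns a c10::SymInt, which remains symbolic when the
      // tensor's shape is being traced, whereas size(-1) would return a
      // concrete int64_t and force a specialization.
      c10::SymInt num_features = grad.sym_size(-1);
      // The *_symint overloads accept SymInt dimensions; plain at::empty
      // would require materializing num_features as an int64_t.
      return at::empty_symint({c10::SymInt(0), num_features}, grad.options());
    }

The same substitution drives every change in the hunk: size -> sym_size, numel -> sym_numel, reshape -> reshape_symint, and _sparse_coo_tensor_unsafe -> _sparse_coo_tensor_unsafe_symint.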