Skip to content

Commit

Permalink
Nested fix
Browse files Browse the repository at this point in the history
  • Loading branch information
Elias Ellison authored and pytorchmergebot committed Jun 16, 2022
1 parent f965681 commit c534537
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 2 deletions.
2 changes: 1 addition & 1 deletion aten/src/ATen/core/tensor_type.cpp
Expand Up @@ -253,7 +253,7 @@ TensorTypePtr TensorType::create(const at::Tensor& t) {
VaryingShape<size_t> stride_indices;
VaryingShape<int64_t> strides;
VaryingShape<int64_t> sizes;
if (t.layout() == at::kStrided) {
if (t.layout() == at::kStrided && !t.is_nested()) {
sizes = VaryingShape<int64_t>{t.sizes().vec()};
strides = VaryingShape<int64_t>{t.strides().vec()};
return TensorType::create(
Expand Down
2 changes: 1 addition & 1 deletion torch/csrc/jit/ir/constants.cpp
Expand Up @@ -12,7 +12,7 @@ namespace jit {
bool insertableTensor(const at::Tensor& ten) {
// bail if tensor has no storage i.e. opaque tensor used in MKLdnn.
// or gradients because we have no way of serializing them & are mutable
return !ten.requires_grad() && ten.has_storage();
return !ten.requires_grad() && ten.has_storage() && !ten.is_nested();
}

bool insertableIValue(const IValue& ivalue) {
Expand Down
3 changes: 3 additions & 0 deletions torch/csrc/jit/ir/node_hashing.cpp
Expand Up @@ -23,6 +23,9 @@ bool tensorEqual(const at::Tensor& lhs, const at::Tensor& rhs) {
if (lhs.is_mkldnn() || rhs.is_mkldnn()) {
return false;
}
if (lhs.is_nested() || rhs.is_nested()) {
return false;
}
// If device is not equal, lhs.equal(rhs) would throw an error.
if (lhs.device() != rhs.device()) {
return false;
Expand Down

0 comments on commit c534537

Please sign in to comment.