diff --git a/aten/src/ATen/core/tensor_type.cpp b/aten/src/ATen/core/tensor_type.cpp index 87972825d291..af0fa16db6e0 100644 --- a/aten/src/ATen/core/tensor_type.cpp +++ b/aten/src/ATen/core/tensor_type.cpp @@ -253,7 +253,7 @@ TensorTypePtr TensorType::create(const at::Tensor& t) { VaryingShape stride_indices; VaryingShape strides; VaryingShape sizes; - if (t.layout() == at::kStrided) { + if (t.layout() == at::kStrided && !t.is_nested()) { sizes = VaryingShape{t.sizes().vec()}; strides = VaryingShape{t.strides().vec()}; return TensorType::create( diff --git a/torch/csrc/jit/ir/constants.cpp b/torch/csrc/jit/ir/constants.cpp index 64fcf807b039..bdc2020d17be 100644 --- a/torch/csrc/jit/ir/constants.cpp +++ b/torch/csrc/jit/ir/constants.cpp @@ -12,7 +12,7 @@ namespace jit { bool insertableTensor(const at::Tensor& ten) { // bail if tensor has no storage i.e. opaque tensor used in MKLdnn. // or gradients because we have no way of serializing them & are mutable - return !ten.requires_grad() && ten.has_storage(); + return !ten.requires_grad() && ten.has_storage() && !ten.is_nested(); } bool insertableIValue(const IValue& ivalue) { diff --git a/torch/csrc/jit/ir/node_hashing.cpp b/torch/csrc/jit/ir/node_hashing.cpp index c65eda562458..033c783386f4 100644 --- a/torch/csrc/jit/ir/node_hashing.cpp +++ b/torch/csrc/jit/ir/node_hashing.cpp @@ -23,6 +23,9 @@ bool tensorEqual(const at::Tensor& lhs, const at::Tensor& rhs) { if (lhs.is_mkldnn() || rhs.is_mkldnn()) { return false; } + if (lhs.is_nested() || rhs.is_nested()) { + return false; + } // If device is not equal, lhs.equal(rhs) would throw an error. if (lhs.device() != rhs.device()) { return false;