From c534537fe6ec936c49bb6699a921c6c5b1bdff28 Mon Sep 17 00:00:00 2001 From: Elias Ellison Date: Mon, 13 Jun 2022 16:58:58 -0700 Subject: [PATCH] Nested fix --- aten/src/ATen/core/tensor_type.cpp | 2 +- torch/csrc/jit/ir/constants.cpp | 2 +- torch/csrc/jit/ir/node_hashing.cpp | 3 +++ 3 files changed, 5 insertions(+), 2 deletions(-) diff --git a/aten/src/ATen/core/tensor_type.cpp b/aten/src/ATen/core/tensor_type.cpp index 87972825d291..af0fa16db6e0 100644 --- a/aten/src/ATen/core/tensor_type.cpp +++ b/aten/src/ATen/core/tensor_type.cpp @@ -253,7 +253,7 @@ TensorTypePtr TensorType::create(const at::Tensor& t) { VaryingShape stride_indices; VaryingShape strides; VaryingShape sizes; - if (t.layout() == at::kStrided) { + if (t.layout() == at::kStrided && !t.is_nested()) { sizes = VaryingShape{t.sizes().vec()}; strides = VaryingShape{t.strides().vec()}; return TensorType::create( diff --git a/torch/csrc/jit/ir/constants.cpp b/torch/csrc/jit/ir/constants.cpp index 64fcf807b039..bdc2020d17be 100644 --- a/torch/csrc/jit/ir/constants.cpp +++ b/torch/csrc/jit/ir/constants.cpp @@ -12,7 +12,7 @@ namespace jit { bool insertableTensor(const at::Tensor& ten) { // bail if tensor has no storage i.e. opaque tensor used in MKLdnn. // or gradients because we have no way of serializing them & are mutable - return !ten.requires_grad() && ten.has_storage(); + return !ten.requires_grad() && ten.has_storage() && !ten.is_nested(); } bool insertableIValue(const IValue& ivalue) { diff --git a/torch/csrc/jit/ir/node_hashing.cpp b/torch/csrc/jit/ir/node_hashing.cpp index c65eda562458..033c783386f4 100644 --- a/torch/csrc/jit/ir/node_hashing.cpp +++ b/torch/csrc/jit/ir/node_hashing.cpp @@ -23,6 +23,9 @@ bool tensorEqual(const at::Tensor& lhs, const at::Tensor& rhs) { if (lhs.is_mkldnn() || rhs.is_mkldnn()) { return false; } + if (lhs.is_nested() || rhs.is_nested()) { + return false; + } // If device is not equal, lhs.equal(rhs) would throw an error. if (lhs.device() != rhs.device()) { return false;