diff --git a/test/cpp/api/serialize.cpp b/test/cpp/api/serialize.cpp index 05bb0f941d40..20d572853d3a 100644 --- a/test/cpp/api/serialize.cpp +++ b/test/cpp/api/serialize.cpp @@ -288,6 +288,14 @@ TEST(SerializeTest, MathBits) { ASSERT_EQ(actual.sizes().vec(), expected.sizes().vec()); ASSERT_TRUE(actual.allclose(expected)); } + + { + // We don't support serializing `ZeroTensor` as it is not public facing yet. + // If in future, `ZeroTensor` serialization is supported, this test should + // start failing! + auto t = torch::_efficientzerotensor({5, 5}); + ASSERT_THROWS_WITH(save_and_load(t), "ZeroTensor is not serializable,"); + } } TEST(SerializeTest, BasicToFile) { diff --git a/test/test_serialization.py b/test/test_serialization.py index af0317e87a14..779d6fb5c20c 100644 --- a/test/test_serialization.py +++ b/test/test_serialization.py @@ -931,6 +931,28 @@ def _save_load_check(t): t_n_c = torch._neg_view(torch.conj(t)) _save_load_check(t_n_c) + @parametrize('weights_only', (False, True)) + def test_serialization_efficient_zerotensor(self, weights_only): + # We don't support serializing `ZeroTensor` as it is not public + # facing yet. + # If in future, `ZeroTensor` serialization is supported, this test + # should start failing! + t = torch._efficientzerotensor((4, 5)) + + def _save_load_check(t): + with BytesIOContext() as f: + torch.save(t, f) + f.seek(0) + # Unsafe load should work + self.assertEqual(torch.load(f, weights_only=weights_only), t) + + # NOTE: `torch.save` fails before we hit the TORCH_CHECK in `getTensoMetadata` + # as nullptr storage is disabled. + err_msg = (r'python bindings to nullptr storage \(e.g., from torch.Tensor._make_wrapper_subclass\)' + ' are currently unsafe and thus disabled') + with self.assertRaisesRegex(RuntimeError, err_msg): + _save_load_check(t) + def run(self, *args, **kwargs): with serialization_method(use_zip=True): return super(TestSerialization, self).run(*args, **kwargs) diff --git a/torch/csrc/jit/serialization/pickler.h b/torch/csrc/jit/serialization/pickler.h index c289cae12b64..26f9fcf42396 100644 --- a/torch/csrc/jit/serialization/pickler.h +++ b/torch/csrc/jit/serialization/pickler.h @@ -300,6 +300,12 @@ bool checkHasValidSetGetState(const std::shared_ptr& cls); // For now, it only takes care of `conj` and `neg` bit. inline std::unordered_map getTensorMetadata( const at::Tensor& t) { + // We don't support serializing `ZeroTensor` as it is not public + // facing yet. + TORCH_CHECK( + !t._is_zerotensor(), + "ZeroTensor is not serializable,", + " please file an issue if required."); std::unordered_map metadata{}; // Only add meta-data if the value is not default.