Skip to content

Commit

Permalink
Error on ZeroTensor serialization (#88803)
Browse files Browse the repository at this point in the history
Follow-up : #88182 (comment)

Pull Request resolved: #88803
Approved by: https://github.com/anjali411
  • Loading branch information
kshitij12345 authored and pytorchmergebot committed Nov 11, 2022
1 parent b843f4d commit d15a6b0
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 0 deletions.
8 changes: 8 additions & 0 deletions test/cpp/api/serialize.cpp
Expand Up @@ -288,6 +288,14 @@ TEST(SerializeTest, MathBits) {
ASSERT_EQ(actual.sizes().vec(), expected.sizes().vec());
ASSERT_TRUE(actual.allclose(expected));
}

{
// We don't support serializing `ZeroTensor` as it is not public facing yet.
// If in future, `ZeroTensor` serialization is supported, this test should
// start failing!
auto t = torch::_efficientzerotensor({5, 5});
ASSERT_THROWS_WITH(save_and_load(t), "ZeroTensor is not serializable,");
}
}

TEST(SerializeTest, BasicToFile) {
Expand Down
22 changes: 22 additions & 0 deletions test/test_serialization.py
Expand Up @@ -931,6 +931,28 @@ def _save_load_check(t):
t_n_c = torch._neg_view(torch.conj(t))
_save_load_check(t_n_c)

@parametrize('weights_only', (False, True))
def test_serialization_efficient_zerotensor(self, weights_only):
# We don't support serializing `ZeroTensor` as it is not public
# facing yet.
# If in future, `ZeroTensor` serialization is supported, this test
# should start failing!
t = torch._efficientzerotensor((4, 5))

def _save_load_check(t):
with BytesIOContext() as f:
torch.save(t, f)
f.seek(0)
# Unsafe load should work
self.assertEqual(torch.load(f, weights_only=weights_only), t)

# NOTE: `torch.save` fails before we hit the TORCH_CHECK in `getTensoMetadata`
# as nullptr storage is disabled.
err_msg = (r'python bindings to nullptr storage \(e.g., from torch.Tensor._make_wrapper_subclass\)'
' are currently unsafe and thus disabled')
with self.assertRaisesRegex(RuntimeError, err_msg):
_save_load_check(t)

def run(self, *args, **kwargs):
with serialization_method(use_zip=True):
return super(TestSerialization, self).run(*args, **kwargs)
Expand Down
6 changes: 6 additions & 0 deletions torch/csrc/jit/serialization/pickler.h
Expand Up @@ -300,6 +300,12 @@ bool checkHasValidSetGetState(const std::shared_ptr<c10::ClassType>& cls);
// For now, it only takes care of `conj` and `neg` bit.
inline std::unordered_map<std::string, bool> getTensorMetadata(
const at::Tensor& t) {
// We don't support serializing `ZeroTensor` as it is not public
// facing yet.
TORCH_CHECK(
!t._is_zerotensor(),
"ZeroTensor is not serializable,",
" please file an issue if required.");
std::unordered_map<std::string, bool> metadata{};

// Only add meta-data if the value is not default.
Expand Down

0 comments on commit d15a6b0

Please sign in to comment.