Skip to content

Commit

Permalink
Add weights_only option to torch.load (pytorch#87443)
Browse files Browse the repository at this point in the history
* Tweak several serialization tests to store the model's state_dict (pytorch#87143)

Namely, change:
- `test_meta_serialization`
- `test_serialization_2gb_file`
- `test_pathlike_serialization`
Pull Request resolved: pytorch#87143
Approved by: https://github.com/ezyang

(cherry picked from commit 4a533f1)

* Add `weights_only` option to `torch.load` (pytorch#86812)

This addresses the security issue in Python's default `Unpickler`, which allows arbitrary code execution while unpickling.
Restrict the classes allowed to be unpickled to `None`, `int`, `bool`, `str`, `float`, `list`, `tuple`, `dict`/`OrderedDict`, as well as `torch.Size`, `torch.nn.Parameter`, and the `torch.Tensor` and `torch.Storage` variants.

`weights_only` defaults to `False`, but safe (weights-only) loading can be forced globally via the `TORCH_FORCE_WEIGHTS_ONLY_LOAD` environment variable.

To some extent, addresses pytorch#52596
Pull Request resolved: pytorch#86812
Approved by: https://github.com/ezyang

(cherry picked from commit 961ebca)
  • Loading branch information
malfet committed Oct 21, 2022
1 parent 59686b4 commit 0c0df0b
Show file tree
Hide file tree
Showing 3 changed files with 395 additions and 33 deletions.
99 changes: 71 additions & 28 deletions test/test_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,17 +148,17 @@ def test(name_or_buffer):

test(io.BytesIO())

def test_serialization(self):
def _test_serialization(self, weights_only):
# Test serialization with a real file
b = self._test_serialization_data()
with tempfile.NamedTemporaryFile() as f:
torch.save(b, f)
f.seek(0)
c = torch.load(f)
c = torch.load(f, weights_only=weights_only)
self._test_serialization_assert(b, c)
with TemporaryFileName() as fname:
torch.save(b, fname)
c = torch.load(fname)
c = torch.load(fname, weights_only=weights_only)
self._test_serialization_assert(b, c)
# test non-ascii encoding of bytes arrays/strings
# The following bytes are produced by serializing
Expand All @@ -180,12 +180,18 @@ def test_serialization(self):
buf = io.BytesIO(serialized)
utf8_bytes = b'\xc5\xbc\xc4\x85\xc4\x85\xc3\xb3\xc5\xbc\xc4\x85\xc5\xbc'
utf8_str = utf8_bytes.decode('utf-8')
loaded_utf8 = torch.load(buf, encoding='utf-8')
loaded_utf8 = torch.load(buf, weights_only=weights_only, encoding='utf-8')
self.assertEqual(loaded_utf8, [utf8_str, torch.zeros(1, dtype=torch.float), 2])
buf.seek(0)
loaded_bytes = torch.load(buf, encoding='bytes')
loaded_bytes = torch.load(buf, weights_only=weights_only, encoding='bytes')
self.assertEqual(loaded_bytes, [utf8_bytes, torch.zeros(1, dtype=torch.float), 2])

def test_serialization(self):
self._test_serialization(False)

def test_serialization_safe(self):
self._test_serialization(True)

def test_serialization_filelike(self):
# Test serialization (load and save) with a filelike object
b = self._test_serialization_data()
Expand Down Expand Up @@ -279,19 +285,25 @@ def test_serialization_offset_gzip(self):
self.assertTrue(torch.equal(a, b))
self.assertEqual(i, j)

def test_serialization_sparse(self):
def _test_serialization_sparse(self, weights_only):
def _test_serialization(conversion):
x = torch.zeros(3, 3)
x[1][1] = 1
x = conversion(x)
with tempfile.NamedTemporaryFile() as f:
torch.save({"tensor": x}, f)
f.seek(0)
y = torch.load(f)
y = torch.load(f, weights_only=weights_only)
self.assertEqual(x, y["tensor"])
_test_serialization(lambda x: x.to_sparse())
_test_serialization(lambda x: x.to_sparse_csr())

def test_serialization_sparse(self):
self._test_serialization(False)

def test_serialization_sparse_safe(self):
self._test_serialization(True)

def test_serialization_sparse_invalid(self):
x = torch.zeros(3, 3)
x[1][1] = 1
Expand Down Expand Up @@ -358,13 +370,13 @@ def test_serialize_device(self):
device_copied = copy.deepcopy(device)
self.assertEqual(device, device_copied)

def test_serialization_backwards_compat(self):
def _test_serialization_backwards_compat(self, weights_only):
a = [torch.arange(1 + i, 26 + i).view(5, 5).float() for i in range(2)]
b = [a[i % 2] for i in range(4)]
b += [a[0].storage()]
b += [a[0].reshape(-1)[1:4].clone().storage()]
path = download_file('https://download.pytorch.org/test_data/legacy_serialized.pt')
c = torch.load(path)
c = torch.load(path, weights_only=weights_only)
self.assertEqual(b, c, atol=0, rtol=0)
self.assertTrue(isinstance(c[0], torch.FloatTensor))
self.assertTrue(isinstance(c[1], torch.FloatTensor))
Expand Down Expand Up @@ -403,12 +415,17 @@ def __reduce__(self):
old_x = old_cls(x)
torch.save(old_x, f)
f.seek(0)
load_x = torch.load(f)
load_x = torch.load(f, weights_only=weights_only)
self.assertEqual(x.storage(), load_x.storage())
self.assertEqual(x.storage_offset(), load_x.storage_offset())
self.assertEqual(x.size(), load_x.size())
self.assertEqual(x.stride(), load_x.stride())

def test_serialization_backwards_compat(self):
self._test_serialization_backwards_compat(False)

def test_serialization_backwards_compat_safe(self):
self._test_serialization_backwards_compat(True)

def test_serialization_save_warnings(self):
with warnings.catch_warnings(record=True) as warns:
Expand Down Expand Up @@ -680,25 +697,31 @@ def wrapper(*args, **kwargs):
def __exit__(self, *args, **kwargs):
torch.save = self.torch_save

@unittest.skipIf(IS_WINDOWS, "NamedTemporaryFile on windows")
class TestBothSerialization(TestCase):
@unittest.skipIf(IS_WINDOWS, "NamedTemporaryFile on windows")
def test_serialization_new_format_old_format_compat(self, device):
def _test_serialization_new_format_old_format_compat(self, device, weights_only):
x = [torch.ones(200, 200, device=device) for i in range(30)]

def test(f_new, f_old):
torch.save(x, f_new, _use_new_zipfile_serialization=True)
f_new.seek(0)
x_new_load = torch.load(f_new)
x_new_load = torch.load(f_new, weights_only=weights_only)
self.assertEqual(x, x_new_load)

torch.save(x, f_old, _use_new_zipfile_serialization=False)
f_old.seek(0)
x_old_load = torch.load(f_old)
x_old_load = torch.load(f_old, weights_only=weights_only)
self.assertEqual(x_old_load, x_new_load)

with tempfile.NamedTemporaryFile() as f_new, tempfile.NamedTemporaryFile() as f_old:
test(f_new, f_old)

def test_serialization_new_format_old_format_compat(self, device):
self._test_serialization_new_format_old_format_compat(device, False)

def test_serialization_new_format_old_format_compat_safe(self, device):
self._test_serialization_new_format_old_format_compat(device, True)


class TestOldSerialization(TestCase, SerializationMixin):
# unique_key is necessary because on Python 2.7, if a warning passed to
Expand All @@ -721,7 +744,7 @@ def import_module(name, filename):
module = import_module(tmpmodule_name, fname)
torch.save(module.Net(), checkpoint)

# First check that the checkpoint can be loaded without warnings
# First check that the checkpoint can be loaded without warning about unsafe loads
checkpoint.seek(0)
with warnings.catch_warnings(record=True) as w:
loaded = torch.load(checkpoint)
Expand Down Expand Up @@ -771,7 +794,8 @@ def test_serialization_offset(self):
self.assertEqual(i, i_loaded)
self.assertEqual(j, j_loaded)

def test_serialization_offset_filelike(self):
@parametrize('weights_only', (True, False))
def test_serialization_offset_filelike(self, weights_only):
a = torch.randn(5, 5)
b = torch.randn(1024, 1024, 512, dtype=torch.float32)
i, j = 41, 43
Expand All @@ -783,9 +807,9 @@ def test_serialization_offset_filelike(self):
self.assertTrue(f.tell() > 2 * 1024 * 1024 * 1024)
f.seek(0)
i_loaded = pickle.load(f)
a_loaded = torch.load(f)
a_loaded = torch.load(f, weights_only=weights_only)
j_loaded = pickle.load(f)
b_loaded = torch.load(f)
b_loaded = torch.load(f, weights_only=weights_only)
self.assertTrue(torch.equal(a, a_loaded))
self.assertTrue(torch.equal(b, b_loaded))
self.assertEqual(i, i_loaded)
Expand All @@ -797,7 +821,8 @@ def run(self, *args, **kwargs):


class TestSerialization(TestCase, SerializationMixin):
def test_serialization_zipfile(self):
@parametrize('weights_only', (True, False))
def test_serialization_zipfile(self, weights_only):
data = self._test_serialization_data()

def test(name_or_buffer):
Expand All @@ -806,7 +831,7 @@ def test(name_or_buffer):
if hasattr(name_or_buffer, 'seek'):
name_or_buffer.seek(0)

result = torch.load(name_or_buffer)
result = torch.load(name_or_buffer, weights_only=weights_only)
self.assertEqual(result, data)

with tempfile.NamedTemporaryFile() as f:
Expand All @@ -828,28 +853,44 @@ def test_serialization_2gb_file(self):
big_model = torch.nn.Conv2d(20000, 3200, kernel_size=3)

with BytesIOContext() as f:
torch.save(big_model, f)
torch.save(big_model.state_dict(), f)
f.seek(0)
state = torch.load(f)

def test_pathlike_serialization(self):
@parametrize('weights_only', (True, False))
def test_pathlike_serialization(self, weights_only):
model = torch.nn.Conv2d(20, 3200, kernel_size=3)

with TemporaryFileName() as fname:
path = pathlib.Path(fname)
torch.save(model, path)
torch.load(path)
torch.save(model.state_dict(), path)
torch.load(path, weights_only=weights_only)

def test_meta_serialization(self):
@parametrize('weights_only', (True, False))
def test_meta_serialization(self, weights_only):
big_model = torch.nn.Conv2d(20000, 320000, kernel_size=3, device='meta')

with BytesIOContext() as f:
torch.save(big_model, f)
torch.save(big_model.state_dict(), f)
f.seek(0)
state = torch.load(f)
state = torch.load(f, weights_only=weights_only)

self.assertEqual(state['weight'].size(), big_model.weight.size())

self.assertEqual(state.weight.size(), big_model.weight.size())
def test_weights_only_assert(self):
    # An object whose pickle payload is "call print(...)": unpickling it
    # executes code, which is exactly what weights_only=True must refuse.
    class HelloWorld:
        def __reduce__(self):
            return (print, ("Hello World!",))

    with BytesIOContext() as buf:
        torch.save(HelloWorld(), buf)

        # Unrestricted load executes print(), so the loaded value is
        # print's return value, None.
        buf.seek(0)
        unsafe_result = torch.load(buf, weights_only=False)
        self.assertIsNone(unsafe_result)

        # Restricted load must reject the payload instead of running it.
        buf.seek(0)
        with self.assertRaisesRegex(pickle.UnpicklingError, "Unsupported class"):
            torch.load(buf, weights_only=True)

def run(self, *args, **kwargs):
with serialization_method(use_zip=True):
Expand Down Expand Up @@ -983,6 +1024,8 @@ def test_empty_class_serialization(self):

# Generate the per-device variants of the device-generic tests.
instantiate_device_type_tests(TestBothSerialization, globals())
instantiate_parametrized_tests(TestSubclassSerialization)
# TestSerialization uses @parametrize('weights_only', ...) on several tests
# (and TestOldSerialization presumably does too -- confirm against the full
# file), so both classes are passed to instantiate_parametrized_tests as well.
instantiate_parametrized_tests(TestOldSerialization)
instantiate_parametrized_tests(TestSerialization)

if __name__ == '__main__':
run_tests()

0 comments on commit 0c0df0b

Please sign in to comment.