From e127f773043cc71c46f56f65c354adbff2ed38c3 Mon Sep 17 00:00:00 2001 From: John Kirkham Date: Fri, 10 Sep 2021 19:44:35 -0700 Subject: [PATCH] Test pickle protocols 4 & 5 Covers testing of both pickle protocol 4 & 5 in the pickle serialization test suite. This should help catch assumptions predicated on protocol 5 being in use that are not always the case (as happened with the test failure in cloudpickle recently). --- distributed/protocol/tests/test_pickle.py | 71 ++++++++++++++--------- 1 file changed, 45 insertions(+), 26 deletions(-) diff --git a/distributed/protocol/tests/test_pickle.py b/distributed/protocol/tests/test_pickle.py index a4ab843564..436eb78ad5 100644 --- a/distributed/protocol/tests/test_pickle.py +++ b/distributed/protocol/tests/test_pickle.py @@ -29,40 +29,46 @@ def __reduce_ex__(self, protocol): return MemoryviewHolder, (self.mv.tobytes(),) -def test_pickle_data(): +@pytest.mark.parametrize("protocol", {4, HIGHEST_PROTOCOL}) +def test_pickle_data(protocol): + context = {"pickle-protocol": protocol} + data = [1, b"123", "123", [123], {}, set()] for d in data: - assert loads(dumps(d)) == d - assert deserialize(*serialize(d, serializers=("pickle",))) == d + assert loads(dumps(d, protocol=protocol)) == d + assert deserialize(*serialize(d, serializers=("pickle",), context=context)) == d + +@pytest.mark.parametrize("protocol", {4, HIGHEST_PROTOCOL}) +def test_pickle_out_of_band(protocol): + context = {"pickle-protocol": protocol} -def test_pickle_out_of_band(): mv = memoryview(b"123") mvh = MemoryviewHolder(mv) - if HIGHEST_PROTOCOL >= 5: + if protocol >= 5: l = [] - d = dumps(mvh, buffer_callback=l.append) + d = dumps(mvh, protocol=protocol, buffer_callback=l.append) mvh2 = loads(d, buffers=l) assert len(l) == 1 assert isinstance(l[0], pickle.PickleBuffer) assert memoryview(l[0]) == mv else: - mvh2 = loads(dumps(mvh)) + mvh2 = loads(dumps(mvh, protocol=protocol)) assert isinstance(mvh2, MemoryviewHolder) assert isinstance(mvh2.mv, memoryview) assert mvh2.mv == mv - h, f = serialize(mvh, serializers=("pickle",)) + h, f = serialize(mvh, serializers=("pickle",), context=context) mvh3 = deserialize(h, f) assert isinstance(mvh3, MemoryviewHolder) assert isinstance(mvh3.mv, memoryview) assert mvh3.mv == mv - if HIGHEST_PROTOCOL >= 5: + if protocol >= 5: assert len(f) == 2 assert isinstance(f[0], bytes) assert isinstance(f[1], memoryview) @@ -72,15 +78,18 @@ def test_pickle_out_of_band(): assert isinstance(f[0], bytes) -def test_pickle_empty(): +@pytest.mark.parametrize("protocol", {4, HIGHEST_PROTOCOL}) +def test_pickle_empty(protocol): + context = {"pickle-protocol": protocol} + x = MemoryviewHolder(bytearray()) # Empty view - header, frames = serialize(x, serializers=("pickle",)) + header, frames = serialize(x, serializers=("pickle",), context=context) assert header["serializer"] == "pickle" assert len(frames) >= 1 assert isinstance(frames[0], bytes) - if HIGHEST_PROTOCOL >= 5: + if protocol >= 5: assert len(frames) == 2 assert len(header["writeable"]) == 1 @@ -98,25 +107,32 @@ def test_pickle_empty(): assert y.mv.readonly -def test_pickle_numpy(): +@pytest.mark.parametrize("protocol", {4, HIGHEST_PROTOCOL}) +def test_pickle_numpy(protocol): np = pytest.importorskip("numpy") + context = {"pickle-protocol": protocol} + x = np.ones(5) - assert (loads(dumps(x)) == x).all() - assert (deserialize(*serialize(x, serializers=("pickle",))) == x).all() + assert (loads(dumps(x, protocol=protocol)) == x).all() + assert ( + deserialize(*serialize(x, serializers=("pickle",), context=context)) == x + ).all() x = np.ones(5000) - assert (loads(dumps(x)) == x).all() - assert (deserialize(*serialize(x, serializers=("pickle",))) == x).all() + assert (loads(dumps(x, protocol=protocol)) == x).all() + assert ( + deserialize(*serialize(x, serializers=("pickle",), context=context)) == x + ).all() x = np.array([np.arange(3), np.arange(4, 6)], dtype=object) - x2 = loads(dumps(x)) + x2 = loads(dumps(x, protocol=protocol)) assert x.shape == x2.shape assert x.dtype == x2.dtype assert x.strides == x2.strides for e_x, e_x2 in zip(x.flat, x2.flat): np.testing.assert_equal(e_x, e_x2) - h, f = serialize(x, serializers=("pickle",)) - if HIGHEST_PROTOCOL >= 5: + h, f = serialize(x, serializers=("pickle",), context=context) + if protocol >= 5: assert len(f) == 3 else: assert len(f) == 1 @@ -127,24 +143,27 @@ def test_pickle_numpy(): for e_x, e_x3 in zip(x.flat, x3.flat): np.testing.assert_equal(e_x, e_x3) - if HIGHEST_PROTOCOL >= 5: + if protocol >= 5: x = np.ones(5000) l = [] - d = dumps(x, buffer_callback=l.append) + d = dumps(x, protocol=protocol, buffer_callback=l.append) assert len(l) == 1 assert isinstance(l[0], pickle.PickleBuffer) assert memoryview(l[0]) == memoryview(x) assert (loads(d, buffers=l) == x).all() - h, f = serialize(x, serializers=("pickle",)) + h, f = serialize(x, serializers=("pickle",), context=context) assert len(f) == 2 assert isinstance(f[0], bytes) assert isinstance(f[1], memoryview) assert (deserialize(h, f) == x).all() -def test_pickle_functions(): +@pytest.mark.parametrize("protocol", {4, HIGHEST_PROTOCOL}) +def test_pickle_functions(protocol): + context = {"pickle-protocol": protocol} + def make_closure(): value = 1 @@ -161,11 +180,11 @@ def funcs(): for func in funcs(): wr = weakref.ref(func) - func2 = loads(dumps(func)) + func2 = loads(dumps(func, protocol=protocol)) wr2 = weakref.ref(func2) assert func2(1) == func(1) - func3 = deserialize(*serialize(func, serializers=("pickle",))) + func3 = deserialize(*serialize(func, serializers=("pickle",), context=context)) wr3 = weakref.ref(func3) assert func3(1) == func(1)