Skip to content

Commit

Permalink
Test pickle protocols 4 & 5 (#5313)
Browse files Browse the repository at this point in the history
Covers testing of both pickle protocol 4 & 5 in the pickle serialization
test suite. This should help catch assumptions predicated on protocol 5
being in use that are not always the case (as happened with the test
failure in cloudpickle recently).
  • Loading branch information
jakirkham committed Sep 11, 2021
1 parent 0774365 commit 518024a
Showing 1 changed file with 45 additions and 26 deletions.
71 changes: 45 additions & 26 deletions distributed/protocol/tests/test_pickle.py
Expand Up @@ -29,40 +29,46 @@ def __reduce_ex__(self, protocol):
return MemoryviewHolder, (self.mv.tobytes(),)


def test_pickle_data():
@pytest.mark.parametrize("protocol", {4, HIGHEST_PROTOCOL})
def test_pickle_data(protocol):
context = {"pickle-protocol": protocol}

data = [1, b"123", "123", [123], {}, set()]
for d in data:
assert loads(dumps(d)) == d
assert deserialize(*serialize(d, serializers=("pickle",))) == d
assert loads(dumps(d, protocol=protocol)) == d
assert deserialize(*serialize(d, serializers=("pickle",), context=context)) == d


@pytest.mark.parametrize("protocol", {4, HIGHEST_PROTOCOL})
def test_pickle_out_of_band(protocol):
context = {"pickle-protocol": protocol}

def test_pickle_out_of_band():
mv = memoryview(b"123")
mvh = MemoryviewHolder(mv)

if HIGHEST_PROTOCOL >= 5:
if protocol >= 5:
l = []
d = dumps(mvh, buffer_callback=l.append)
d = dumps(mvh, protocol=protocol, buffer_callback=l.append)
mvh2 = loads(d, buffers=l)

assert len(l) == 1
assert isinstance(l[0], pickle.PickleBuffer)
assert memoryview(l[0]) == mv
else:
mvh2 = loads(dumps(mvh))
mvh2 = loads(dumps(mvh, protocol=protocol))

assert isinstance(mvh2, MemoryviewHolder)
assert isinstance(mvh2.mv, memoryview)
assert mvh2.mv == mv

h, f = serialize(mvh, serializers=("pickle",))
h, f = serialize(mvh, serializers=("pickle",), context=context)
mvh3 = deserialize(h, f)

assert isinstance(mvh3, MemoryviewHolder)
assert isinstance(mvh3.mv, memoryview)
assert mvh3.mv == mv

if HIGHEST_PROTOCOL >= 5:
if protocol >= 5:
assert len(f) == 2
assert isinstance(f[0], bytes)
assert isinstance(f[1], memoryview)
Expand All @@ -72,15 +78,18 @@ def test_pickle_out_of_band():
assert isinstance(f[0], bytes)


def test_pickle_empty():
@pytest.mark.parametrize("protocol", {4, HIGHEST_PROTOCOL})
def test_pickle_empty(protocol):
context = {"pickle-protocol": protocol}

x = MemoryviewHolder(bytearray()) # Empty view
header, frames = serialize(x, serializers=("pickle",))
header, frames = serialize(x, serializers=("pickle",), context=context)

assert header["serializer"] == "pickle"
assert len(frames) >= 1
assert isinstance(frames[0], bytes)

if HIGHEST_PROTOCOL >= 5:
if protocol >= 5:
assert len(frames) == 2
assert len(header["writeable"]) == 1

Expand All @@ -98,25 +107,32 @@ def test_pickle_empty():
assert y.mv.readonly


def test_pickle_numpy():
@pytest.mark.parametrize("protocol", {4, HIGHEST_PROTOCOL})
def test_pickle_numpy(protocol):
np = pytest.importorskip("numpy")
context = {"pickle-protocol": protocol}

x = np.ones(5)
assert (loads(dumps(x)) == x).all()
assert (deserialize(*serialize(x, serializers=("pickle",))) == x).all()
assert (loads(dumps(x, protocol=protocol)) == x).all()
assert (
deserialize(*serialize(x, serializers=("pickle",), context=context)) == x
).all()

x = np.ones(5000)
assert (loads(dumps(x)) == x).all()
assert (deserialize(*serialize(x, serializers=("pickle",))) == x).all()
assert (loads(dumps(x, protocol=protocol)) == x).all()
assert (
deserialize(*serialize(x, serializers=("pickle",), context=context)) == x
).all()

x = np.array([np.arange(3), np.arange(4, 6)], dtype=object)
x2 = loads(dumps(x))
x2 = loads(dumps(x, protocol=protocol))
assert x.shape == x2.shape
assert x.dtype == x2.dtype
assert x.strides == x2.strides
for e_x, e_x2 in zip(x.flat, x2.flat):
np.testing.assert_equal(e_x, e_x2)
h, f = serialize(x, serializers=("pickle",))
if HIGHEST_PROTOCOL >= 5:
h, f = serialize(x, serializers=("pickle",), context=context)
if protocol >= 5:
assert len(f) == 3
else:
assert len(f) == 1
Expand All @@ -127,24 +143,27 @@ def test_pickle_numpy():
for e_x, e_x3 in zip(x.flat, x3.flat):
np.testing.assert_equal(e_x, e_x3)

if HIGHEST_PROTOCOL >= 5:
if protocol >= 5:
x = np.ones(5000)

l = []
d = dumps(x, buffer_callback=l.append)
d = dumps(x, protocol=protocol, buffer_callback=l.append)
assert len(l) == 1
assert isinstance(l[0], pickle.PickleBuffer)
assert memoryview(l[0]) == memoryview(x)
assert (loads(d, buffers=l) == x).all()

h, f = serialize(x, serializers=("pickle",))
h, f = serialize(x, serializers=("pickle",), context=context)
assert len(f) == 2
assert isinstance(f[0], bytes)
assert isinstance(f[1], memoryview)
assert (deserialize(h, f) == x).all()


def test_pickle_functions():
@pytest.mark.parametrize("protocol", {4, HIGHEST_PROTOCOL})
def test_pickle_functions(protocol):
context = {"pickle-protocol": protocol}

def make_closure():
value = 1

Expand All @@ -161,11 +180,11 @@ def funcs():
for func in funcs():
wr = weakref.ref(func)

func2 = loads(dumps(func))
func2 = loads(dumps(func, protocol=protocol))
wr2 = weakref.ref(func2)
assert func2(1) == func(1)

func3 = deserialize(*serialize(func, serializers=("pickle",)))
func3 = deserialize(*serialize(func, serializers=("pickle",), context=context))
wr3 = weakref.ref(func3)
assert func3(1) == func(1)

Expand Down

0 comments on commit 518024a

Please sign in to comment.