Skip to content

Commit

Permalink
Ensure inproc properly emulates serialization protocol (#8622)
Browse files Browse the repository at this point in the history
  • Loading branch information
fjetter committed Apr 18, 2024
1 parent 3f13a2d commit f621c65
Show file tree
Hide file tree
Showing 6 changed files with 42 additions and 16 deletions.
5 changes: 2 additions & 3 deletions distributed/comm/inproc.py
Expand Up @@ -13,7 +13,7 @@

from distributed.comm.core import BaseListener, Comm, CommClosedError, Connector
from distributed.comm.registry import Backend, backends
from distributed.protocol import nested_deserialize
from distributed.protocol.serialize import _nested_deserialize
from distributed.utils import get_ip, is_python_shutting_down

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -218,8 +218,7 @@ async def read(self, deserializers="ignored"):
self._finalizer.detach()
raise CommClosedError()

if self.deserialize:
msg = nested_deserialize(msg)
msg = _nested_deserialize(msg, self.deserialize)
return msg

async def write(self, msg, serializers=None, on_error=None):
Expand Down
2 changes: 2 additions & 0 deletions distributed/protocol/__init__.py
Expand Up @@ -6,8 +6,10 @@
from distributed.protocol.core import decompress, dumps, loads, maybe_compress, msgpack
from distributed.protocol.cuda import cuda_deserialize, cuda_serialize
from distributed.protocol.serialize import (
Pickled,
Serialize,
Serialized,
ToPickle,
dask_deserialize,
dask_serialize,
deserialize,
Expand Down
27 changes: 21 additions & 6 deletions distributed/protocol/serialize.py
Expand Up @@ -3,6 +3,7 @@
import codecs
import importlib
import traceback
import warnings
from array import array
from enum import Enum
from functools import partial
Expand Down Expand Up @@ -621,6 +622,14 @@ def __ne__(self, other):


def nested_deserialize(x):
warnings.warn(
"nested_deserialize is deprecated and will be removed in a future release.",
DeprecationWarning,
)
return _nested_deserialize(x, emulate_deserialize=True)


def _nested_deserialize(x, emulate_deserialize=True):
"""
Replace all Serialize and Serialized values nested in *x*
with the original values. Returns a copy of *x*.
Expand All @@ -637,21 +646,27 @@ def replace_inner(x):
typ = type(v)
if typ is dict or typ is list:
x[k] = replace_inner(v)
elif typ is Serialize:
if emulate_deserialize:
if typ is Serialize:
x[k] = v.data
elif typ is Serialized:
x[k] = deserialize(v.header, v.frames)
if typ is ToPickle:
x[k] = v.data
elif typ is Serialized:
x[k] = deserialize(v.header, v.frames)

elif type(x) is list:
x = list(x)
for k, v in enumerate(x):
typ = type(v)
if typ is dict or typ is list:
x[k] = replace_inner(v)
elif typ is Serialize:
if emulate_deserialize:
if typ is Serialize:
x[k] = v.data
elif typ is Serialized:
x[k] = deserialize(v.header, v.frames)
if typ is ToPickle:
x[k] = v.data
elif typ is Serialized:
x[k] = deserialize(v.header, v.frames)

return x

Expand Down
17 changes: 15 additions & 2 deletions distributed/protocol/tests/test_serialize.py
Expand Up @@ -20,12 +20,12 @@
from distributed.protocol import (
Serialize,
Serialized,
ToPickle,
dask_serialize,
deserialize,
deserialize_bytes,
dumps,
loads,
nested_deserialize,
register_serialization,
register_serialization_family,
serialize,
Expand All @@ -35,6 +35,7 @@
)
from distributed.protocol.serialize import (
_is_msgpack_serializable,
_nested_deserialize,
check_dask_serializable,
)
from distributed.utils import ensure_memoryview, nbytes
Expand Down Expand Up @@ -166,12 +167,24 @@ def test_nested_deserialize():
"x": [to_serialize(123), to_serialize(456), 789],
"y": {"a": ["abc", Serialized(*serialize("def"))], "b": b"ghi"},
}

x_orig = copy.deepcopy(x)
assert _nested_deserialize(x, emulate_deserialize=False) == x_orig

assert x == x_orig # x wasn't mutated
x["topickle"] = ToPickle(1)
x["topickle_nested"] = [1, ToPickle(2)]
x_orig = copy.deepcopy(x)
assert (out := _nested_deserialize(x, emulate_deserialize=False)) != x_orig
assert out["topickle"] == 1
assert out["topickle_nested"] == [1, 2]

assert nested_deserialize(x) == {
assert _nested_deserialize(x) == {
"op": "update",
"x": [123, 456, 789],
"y": {"a": ["abc", "def"], "b": b"ghi"},
"topickle": 1,
"topickle_nested": [1, 2],
}
assert x == x_orig # x wasn't mutated

Expand Down
3 changes: 0 additions & 3 deletions distributed/scheduler.py
Expand Up @@ -4676,9 +4676,6 @@ async def update_graph(
annotations: dict | None = None,
stimulus_id: str | None = None,
) -> None:
# FIXME: Apparently empty dicts arrive as a ToPickle object
if isinstance(annotations, ToPickle):
annotations = annotations.data # type: ignore[unreachable]
start = time()
try:
try:
Expand Down
4 changes: 2 additions & 2 deletions distributed/shuffle/tests/utils.py
Expand Up @@ -22,15 +22,15 @@ def __init__(self, shuffle: ShuffleRun):

def __getattr__(self, key):
async def _(**kwargs):
from distributed.protocol.serialize import nested_deserialize
from distributed.protocol.serialize import _nested_deserialize

method_name = key.replace("shuffle_", "")
kwargs.pop("shuffle_id", None)
kwargs.pop("run_id", None)
# TODO: This is a bit awkward. At some point the arguments are
# already getting wrapped with a `Serialize`. We only want to unwrap
# here.
kwargs = nested_deserialize(kwargs)
kwargs = _nested_deserialize(kwargs)
meth = getattr(self.shuffle, method_name)
return await meth(**kwargs)

Expand Down

0 comments on commit f621c65

Please sign in to comment.