Encapsulate serialization
crusaderky committed Apr 2, 2024
1 parent 39d4112 commit 61a4c92
Showing 4 changed files with 108 additions and 35 deletions.
2 changes: 1 addition & 1 deletion distributed/shuffle/_core.py
@@ -296,7 +296,7 @@ def fail(self, exception: Exception) -> None:
if not self.closed:
self._exception = exception

def _read_from_disk(self, id: NDIndex) -> list[Any]: # TODO: Typing
def _read_from_disk(self, id: NDIndex) -> Any:
self.raise_if_closed()
return self._disk_buffer.read("_".join(str(i) for i in id))

40 changes: 29 additions & 11 deletions distributed/shuffle/_disk.py
@@ -6,7 +6,7 @@
import shutil
import threading
from collections.abc import Generator, Iterator
from contextlib import contextmanager
from contextlib import contextmanager, nullcontext
from pathlib import Path
from typing import Any

@@ -123,6 +123,11 @@ class DiskShardsBuffer(ShardsBuffer):
implementation of this scheme.
"""

directory: pathlib.Path
_closed: bool
_use_raw_buffers: bool | None
_directory_lock: ReadWriteLock

def __init__(
self,
directory: str | pathlib.Path,
@@ -136,6 +141,7 @@ def __init__(
self.directory = pathlib.Path(directory)
self.directory.mkdir(exist_ok=True)
self._closed = False
self._use_raw_buffers = None
self._directory_lock = ReadWriteLock()

@log_errors
@@ -152,14 +158,23 @@ async def _process(self, id: str, shards: list[Any]) -> None:
future then we should consider simplifying this considerably and
dropping the write into communicate above.
"""
assert shards
if self._use_raw_buffers is None:
self._use_raw_buffers = isinstance(shards[0], list) and isinstance(
shards[0][0], (bytes, bytearray, memoryview)
)
serialize_ctx = (
nullcontext()
if self._use_raw_buffers
else context_meter.meter("serialize", func=thread_time)
)

nbytes_acc = 0

def pickle_and_tally() -> Iterator[bytes | bytearray | memoryview]:
nonlocal nbytes_acc
for shard in shards:
if isinstance(shard, list) and isinstance(
shard[0], (bytes, bytearray, memoryview)
):
if self._use_raw_buffers:
# list[bytes | bytearray | memoryview] for dataframe shuffle
# Shard was pre-serialized before being sent over the network.
nbytes_acc += sum(map(nbytes, shard))
@@ -173,7 +188,7 @@ def pickle_and_tally() -> Iterator[bytes | bytearray | memoryview]:
with (
self._directory_lock.read(),
context_meter.meter("disk-write"),
context_meter.meter("serialize", func=thread_time),
serialize_ctx,
):
if self._closed:
raise RuntimeError("Already closed")
@@ -184,7 +199,7 @@ def pickle_and_tally() -> Iterator[bytes | bytearray | memoryview]:
context_meter.digest_metric("disk-write", 1, "count")
context_meter.digest_metric("disk-write", nbytes_acc, "bytes")

def read(self, id: str) -> list[Any]:
def read(self, id: str) -> Any:
"""Read a complete file back into memory"""
self.raise_on_exception()
if not self._inputs_done:
@@ -211,8 +226,7 @@ def read(self, id: str) -> list[Any]:
else:
raise DataUnavailable(id)

@staticmethod
def _read(path: Path) -> tuple[list[Any], int]:
def _read(self, path: Path) -> tuple[Any, int]:
"""Open a memory-mapped file descriptor to disk, read all metadata, and unpickle
all arrays. This is a fast sequence of short reads interleaved with seeks.
Do not read in memory the actual data; the arrays' buffers will point to the
@@ -224,10 +238,14 @@ def _read(path: Path) -> tuple[list[Any], int]:
"""
with path.open(mode="r+b") as fh:
buffer = memoryview(mmap.mmap(fh.fileno(), 0))

# The file descriptor has *not* been closed!
shards = list(unpickle_bytestream(buffer))
return shards, buffer.nbytes

assert self._use_raw_buffers is not None
if self._use_raw_buffers:
return buffer, buffer.nbytes
else:
shards = list(unpickle_bytestream(buffer))
return shards, buffer.nbytes

async def close(self) -> None:
await super().close()
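Note on the pattern above (illustrative, not part of the diff): DiskShardsBuffer now decides once, on the first batch it receives, whether shards arrive as pre-serialized raw frames (dataframe shuffle) or as arbitrary objects that still need pickling (array rechunk), and only meters "serialize" time in the latter case. A minimal standalone sketch of that pattern follows; TinyShardsBuffer and the meter() timer are hypothetical stand-ins for the real buffer and context_meter.meter.

# Illustrative sketch, not part of this commit.
import pickle
import time
from contextlib import contextmanager, nullcontext


@contextmanager
def meter(label: str):
    # Hypothetical stand-in for distributed.metrics.context_meter.meter
    start = time.perf_counter()
    try:
        yield
    finally:
        print(f"{label}: {time.perf_counter() - start:.6f}s")


class TinyShardsBuffer:
    def __init__(self) -> None:
        self._use_raw_buffers: bool | None = None

    def process(self, shards: list) -> list[bytes]:
        assert shards
        # Decide once, on the first batch: are shards pre-serialized frames
        # (dataframe shuffle) or arbitrary objects that still need pickling?
        if self._use_raw_buffers is None:
            self._use_raw_buffers = isinstance(shards[0], list) and isinstance(
                shards[0][0], (bytes, bytearray, memoryview)
            )
        # Skip the "serialize" metering entirely for pre-serialized frames.
        ctx = nullcontext() if self._use_raw_buffers else meter("serialize")
        with ctx:
            if self._use_raw_buffers:
                return [bytes(frame) for shard in shards for frame in shard]
            return [pickle.dumps(shard) for shard in shards]

In the commit itself the frames are streamed to disk under a read lock instead of returned, but the one-time mode detection and the nullcontext swap are the same.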
72 changes: 71 additions & 1 deletion distributed/shuffle/_pickle.py
@@ -2,10 +2,15 @@

import pickle
from collections.abc import Iterator
from typing import Any
from typing import TYPE_CHECKING, Any

from toolz import first

from distributed.protocol.utils import pack_frames_prelude, unpack_frames

if TYPE_CHECKING:
import pandas as pd


def pickle_bytelist(obj: object, prelude: bool = True) -> list[pickle.PickleBuffer]:
"""Variant of :func:`serialize_bytelist`, that doesn't support compression, locally
@@ -39,3 +44,68 @@ def unpickle_bytestream(b: bytes | bytearray | memoryview) -> Iterator[Any]:
if remainder.nbytes == 0:
break
b = remainder


def pickle_dataframe_shard(
input_part_id: int,
shard: pd.DataFrame,
) -> list[pickle.PickleBuffer]:
"""Optimized pickler for pandas Dataframes. DIscard all unnecessary metadata
(like the columns header).
Parameters:
obj: pandas
"""
return pickle_bytelist(
(input_part_id, shard.index, *shard._mgr.blocks), prelude=False
)


def unpickle_and_concat_dataframe_shards(
b: bytes | bytearray | memoryview, meta: pd.DataFrame
) -> pd.DataFrame:
"""Optimized unpickler for pandas Dataframes.
Parameters
----------
b:
raw buffer, containing the concatenation of the outputs of
:func:`pickle_dataframe_shard`, in arbitrary order
meta:
DataFrame header
Returns
-------
Reconstructed output shard, sorted by input partition ID
**Roundtrip example**
>>> import random
>>> import pandas as pd
>>> from toolz import concat
>>> df = pd.DataFrame(...) # Input partition
>>> meta = df.iloc[:0].copy()
>>> shards = df.iloc[0:10], df.iloc[10:20], ...
>>> frames = [pickle_dataframe_shard(i, shard) for i, shard in enumerate(shards)]
>>> random.shuffle(frames) # Simulate the frames arriving in arbitrary order
>>> blob = bytearray(b"".join(concat(frames))) # Simulate disk roundtrip
>>> df2 = unpickle_and_concat_dataframe_shards(blob, meta)
"""
import pandas as pd
from pandas.core.internals import BlockManager

parts = list(unpickle_bytestream(b))
# [(input_part_id, index, *blocks), ...]
parts = sorted(parts, key=first)
shards = []
for _, idx, *blocks in parts:
axes = [meta.columns, idx]
df = pd.DataFrame._from_mgr( # type: ignore[attr-defined]
BlockManager(blocks, axes, verify_integrity=False), axes
)
shards.append(df)

# Actually load memory-mapped buffers into memory and close the file
# descriptors
return pd.concat(shards, copy=True)
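A note on the two new helpers (illustrative, not part of the diff): pickle_dataframe_shard pickles only each shard's row index and raw block-manager blocks, dropping the columns header, and unpickle_and_concat_dataframe_shards puts the header back from meta. The toy sketch below shows that rebuild step on a single frame; DataFrame._from_mgr and BlockManager are private pandas internals, so the exact calls may vary across pandas versions.

# Illustrative sketch, not part of this commit.
import pandas as pd
from pandas.core.internals import BlockManager

df = pd.DataFrame({"a": [1, 2, 3], "b": [4.0, 5.0, 6.0]})
meta = df.iloc[:0]  # header-only frame: columns and dtypes, zero rows

# What pickle_dataframe_shard() pickles per shard: the row index plus the raw
# blocks; the columns header is deliberately dropped.
idx, blocks = df.index, df._mgr.blocks

# What unpickle_and_concat_dataframe_shards() rebuilds per shard, reusing the
# columns from meta. BlockManager axes are ordered [columns, index].
axes = [meta.columns, idx]
rebuilt = pd.DataFrame._from_mgr(  # private pandas API
    BlockManager(blocks, axes, verify_integrity=False), axes
)
assert rebuilt.equals(df)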
29 changes: 7 additions & 22 deletions distributed/shuffle/_shuffle.py
@@ -17,7 +17,6 @@
from pickle import PickleBuffer
from typing import TYPE_CHECKING, Any

from toolz import first
from tornado.ioloop import IOLoop

import dask
@@ -42,7 +41,10 @@
)
from distributed.shuffle._exceptions import DataUnavailable
from distributed.shuffle._limiter import ResourceLimiter
from distributed.shuffle._pickle import pickle_bytelist
from distributed.shuffle._pickle import (
pickle_dataframe_shard,
unpickle_and_concat_dataframe_shards,
)
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
from distributed.utils import nbytes

@@ -335,9 +337,7 @@ def split_by_worker(
assert isinstance(output_part_id, int)
if drop_column:
del part[column]
frames = pickle_bytelist(
(input_part_id, part.index, *part._mgr.blocks), prelude=False
)
frames = pickle_dataframe_shard(input_part_id, part)
out[worker_for[output_part_id]].append((output_part_id, frames))

return {k: (input_part_id, v) for k, v in out.items()}
@@ -516,31 +516,16 @@ def _get_output_partition(
key: Key,
**kwargs: Any,
) -> pd.DataFrame:
import pandas as pd
from pandas.core.internals import BlockManager

meta = self.meta.copy()
if self.drop_column:
meta = self.meta.drop(columns=self.column)

try:
parts = self._read_from_disk((partition_id,))
buffer = self._read_from_disk((partition_id,))
except DataUnavailable:
return meta

# [(input_part_id, index, *blocks), ...]
parts = sorted(parts, key=first)
shards = []
for _, idx, *blocks in parts:
axes = [meta.columns, idx]
df = pd.DataFrame._from_mgr( # type: ignore[attr-defined]
BlockManager(blocks, axes, verify_integrity=False), axes
)
shards.append(df)

# Actually load memory-mapped buffers into memory and close the file
# descriptors
return pd.concat(shards, copy=True)
return unpickle_and_concat_dataframe_shards(buffer, meta)

def _get_assigned_worker(self, id: int) -> str:
return self.worker_for[id]
