Encapsulate serialization
crusaderky committed Apr 2, 2024
1 parent 39d4112 commit 676b308
Showing 2 changed files with 82 additions and 23 deletions.
78 changes: 76 additions & 2 deletions distributed/shuffle/_pickle.py
@@ -1,11 +1,16 @@
from __future__ import annotations

import pickle
from collections.abc import Iterator
from typing import Any
from collections.abc import Iterable, Iterator
from typing import TYPE_CHECKING, Any

from toolz import first

from distributed.protocol.utils import pack_frames_prelude, unpack_frames

if TYPE_CHECKING:
    import pandas as pd


def pickle_bytelist(obj: object, prelude: bool = True) -> list[pickle.PickleBuffer]:
"""Variant of :func:`serialize_bytelist`, that doesn't support compression, locally
@@ -39,3 +44,72 @@ def unpickle_bytestream(b: bytes | bytearray | memoryview) -> Iterator[Any]:
        if remainder.nbytes == 0:
            break
        b = remainder
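
The two helpers above are meant to be used as a pair: ``pickle_bytelist`` (with its default ``prelude=True``) emits a length-prefixed frame list, and ``unpickle_bytestream`` walks a buffer made of many such lists concatenated back to back. A minimal in-memory sketch of that round trip (not part of this diff; it only assumes the behaviour documented above):

    objs = [{"x": 1}, list(range(5)), "hello"]

    # One length-prefixed frame list per object; these simple objects produce no
    # out-of-band buffers, so every frame is plain bytes.
    blob = bytearray(
        b"".join(frame for obj in objs for frame in pickle_bytelist(obj))
    )

    # unpickle_bytestream keeps consuming the remainder until it is exhausted
    assert list(unpickle_bytestream(blob)) == objs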


def pickle_dataframe_shard(
    input_part_id: int,
    shard: pd.DataFrame,
) -> list[pickle.PickleBuffer]:
"""Optimized pickler for pandas Dataframes. DIscard all unnecessary metadata
(like the columns header).
Parameters:
obj: pandas
"""
    return pickle_bytelist(
        (input_part_id, shard.index, *shard._mgr.blocks), prelude=False
    )


def unpickle_and_concat_dataframe_shards(
    parts: Iterable[Any], meta: pd.DataFrame
) -> pd.DataFrame:
"""Optimized unpickler for pandas Dataframes.
Parameters
----------
parts:
output of ``unpickle_bytestream(b)``, where b is the memory-mapped blob of
pickled data which is the concatenation of the outputs of
:func:`pickle_dataframe_shard` in arbitrary order
meta:
DataFrame header
Returns
-------
Reconstructed output shard, sorted by input partition ID
**Roundtrip example**
.. code-block:: python
import random
import pandas as pd
df = pd.DataFrame(...) # Input partition
meta = df.iloc[:0]
shards = df.iloc[0:10], df.iloc[10:20], ...
frames = [pickle_dataframe_shard(i, shard) for i, shard in enumerate(shards)]
random.shuffle(frames) # Simulate the frames arriving in arbitrary order
frames = [f for fs in frames for f in fs] # Flatten
blob = bytearray(b"".join(frames)) # Simulate disk roundtrip
parts = unpickle_bytestream(blob)
df2 = unpickle_and_concat_dataframe_shards(parts, meta)
"""
    import pandas as pd
    from pandas.core.internals import BlockManager

    # [(input_part_id, index, *blocks), ...]
    parts = sorted(parts, key=first)
    shards = []
    for _, idx, *blocks in parts:
        axes = [meta.columns, idx]
        df = pd.DataFrame._from_mgr(  # type: ignore[attr-defined]
            BlockManager(blocks, axes, verify_integrity=False), axes
        )
        shards.append(df)

    # Actually load memory-mapped buffers into memory and close the file
    # descriptors
    return pd.concat(shards, copy=True)
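
For context, the reconstruction above leans on pandas internals: the column labels are never pickled, so they are re-attached from ``meta`` when each shard is rebuilt. A stand-alone sketch of that block-level trick (a simplified illustration using the same pandas calls as the function above; not part of this diff):

    import pandas as pd
    from pandas.core.internals import BlockManager

    df = pd.DataFrame({"x": [1, 2, 3], "y": [4.0, 5.0, 6.0]})
    meta = df.iloc[:0]       # header only: column labels and dtypes
    blocks = df._mgr.blocks  # raw data blocks, without column labels
    idx = df.index

    # Re-attach the column labels taken from `meta`, as
    # unpickle_and_concat_dataframe_shards does for each shard
    axes = [meta.columns, idx]
    df2 = pd.DataFrame._from_mgr(
        BlockManager(blocks, axes, verify_integrity=False), axes
    )
    assert df2.equals(df)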
27 changes: 6 additions & 21 deletions distributed/shuffle/_shuffle.py
@@ -17,7 +17,6 @@
from pickle import PickleBuffer
from typing import TYPE_CHECKING, Any

from toolz import first
from tornado.ioloop import IOLoop

import dask
@@ -42,7 +41,10 @@
)
from distributed.shuffle._exceptions import DataUnavailable
from distributed.shuffle._limiter import ResourceLimiter
from distributed.shuffle._pickle import pickle_bytelist
from distributed.shuffle._pickle import (
    pickle_dataframe_shard,
    unpickle_and_concat_dataframe_shards,
)
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
from distributed.utils import nbytes

@@ -335,9 +337,7 @@ def split_by_worker(
        assert isinstance(output_part_id, int)
        if drop_column:
            del part[column]
        frames = pickle_bytelist(
            (input_part_id, part.index, *part._mgr.blocks), prelude=False
        )
        frames = pickle_dataframe_shard(input_part_id, part)
        out[worker_for[output_part_id]].append((output_part_id, frames))

    return {k: (input_part_id, v) for k, v in out.items()}
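
For orientation, ``split_by_worker`` returns a mapping shaped like ``{worker address: (input partition id, [(output partition id, frames), ...])}``. A hedged, self-contained sketch of that shape (worker addresses, partition ids, and the ``_partitions`` column name are made up, and the real function's ``drop_column`` handling is glossed over):

    import pandas as pd

    input_part_id = 3
    df = pd.DataFrame({"_partitions": [0, 1, 0], "x": [10, 20, 30]})
    worker_for = {0: "tcp://127.0.0.1:40001", 1: "tcp://127.0.0.1:40002"}

    out: dict[str, list] = {}
    for output_part_id, part in df.groupby("_partitions"):
        frames = pickle_dataframe_shard(input_part_id, part.drop(columns="_partitions"))
        out.setdefault(worker_for[output_part_id], []).append((output_part_id, frames))

    result = {k: (input_part_id, v) for k, v in out.items()}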
@@ -516,9 +516,6 @@ def _get_output_partition(
        key: Key,
        **kwargs: Any,
    ) -> pd.DataFrame:
        import pandas as pd
        from pandas.core.internals import BlockManager

        meta = self.meta.copy()
        if self.drop_column:
            meta = self.meta.drop(columns=self.column)
@@ -528,19 +525,7 @@
        except DataUnavailable:
            return meta

        # [(input_part_id, index, *blocks), ...]
        parts = sorted(parts, key=first)
        shards = []
        for _, idx, *blocks in parts:
            axes = [meta.columns, idx]
            df = pd.DataFrame._from_mgr(  # type: ignore[attr-defined]
                BlockManager(blocks, axes, verify_integrity=False), axes
            )
            shards.append(df)

        # Actually load memory-mapped buffers into memory and close the file
        # descriptors
        return pd.concat(shards, copy=True)
        return unpickle_and_concat_dataframe_shards(parts, meta)

    def _get_assigned_worker(self, id: int) -> str:
        return self.worker_for[id]