[DNM] P2P shuffle without PyArrow #8606

Draft: wants to merge 3 commits into main
11 changes: 11 additions & 0 deletions distributed/protocol/serialize.py
@@ -839,6 +839,17 @@
    return out


+@dask_serialize.register(PickleBuffer)
+def _serialize_picklebuffer(obj):
+    return _serialize_memoryview(obj.raw())

+@dask_deserialize.register(PickleBuffer)
+def _deserialize_picklebuffer(header, frames):
+    out = _deserialize_memoryview(header, frames)
+    return PickleBuffer(out)


#########################
# Descend into __dict__ #
#########################
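For context, a minimal sketch (not part of the diff) of what the two registrations above enable: a pickle.PickleBuffer wrapping existing memory round-trips through distributed's "dask" serializer family, with .raw() exposing the payload as a memoryview so it can travel as a frame instead of being copied into a pickle stream. The serialize/deserialize entry points below are the public helpers in distributed.protocol; the exact frame layout is an internal detail.

from pickle import PickleBuffer

from distributed.protocol import deserialize, serialize

# Wrap an existing bytes-like object; PickleBuffer itself makes no copy.
buf = PickleBuffer(bytearray(b"shard payload"))

# With the registrations above, the "dask" family can handle it directly.
header, frames = serialize(buf, serializers=("dask",))

# The receiving side rebuilds a PickleBuffer around the transferred frame.
restored = deserialize(header, frames)
assert isinstance(restored, PickleBuffer)
assert bytes(restored.raw()) == b"shard payload"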
2 changes: 0 additions & 2 deletions distributed/shuffle/__init__.py
@@ -1,14 +1,12 @@
from __future__ import annotations

-from distributed.shuffle._arrow import check_minimal_arrow_version
from distributed.shuffle._merge import HashJoinP2PLayer, hash_join_p2p
from distributed.shuffle._rechunk import rechunk_p2p
from distributed.shuffle._scheduler_plugin import ShuffleSchedulerPlugin
from distributed.shuffle._shuffle import P2PShuffleLayer, rearrange_by_column_p2p
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin

__all__ = [
-    "check_minimal_arrow_version",
    "hash_join_p2p",
    "HashJoinP2PLayer",
    "P2PShuffleLayer",
201 changes: 0 additions & 201 deletions distributed/shuffle/_arrow.py

This file was deleted.

17 changes: 4 additions & 13 deletions distributed/shuffle/_core.py
@@ -19,7 +19,6 @@
from dataclasses import dataclass, field
from enum import Enum
from functools import partial
-from pathlib import Path
from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast

from tornado.ioloop import IOLoop
@@ -116,11 +115,10 @@ def __init__(
        if disk:
            self._disk_buffer = DiskShardsBuffer(
                directory=directory,
-                read=self.read,
                memory_limiter=memory_limiter_disk,
            )
        else:
-            self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
+            self._disk_buffer = MemoryShardsBuffer()

        with self._capture_metrics("background-comms"):
            self._comm_buffer = CommShardsBuffer(
@@ -216,7 +214,7 @@ async def send(
            # and unpickle it on the other side.
            # Performance tests informing the size threshold:
            # https://github.com/dask/distributed/pull/8318
-            shards_or_bytes: list | bytes = pickle.dumps(shards)
+            shards_or_bytes: list | bytes = pickle.dumps(shards, protocol=5)
        else:
            shards_or_bytes = shards
@@ -298,7 +296,7 @@ def fail(self, exception: Exception) -> None:
        if not self.closed:
            self._exception = exception

-    def _read_from_disk(self, id: NDIndex) -> list[Any]:  # TODO: Typing
+    def _read_from_disk(self, id: NDIndex) -> Any:
        self.raise_if_closed()
        return self._disk_buffer.read("_".join(str(i) for i in id))
@@ -335,6 +333,7 @@ def add_partition(
        if self.transferred:
            raise RuntimeError(f"Cannot add more partitions to {self}")
        # Log metrics both in the "execute" and in the "p2p" contexts
+        context_meter.digest_metric("p2p-partitions", 1, "count")
        with self._capture_metrics("foreground"):
            with (
                context_meter.meter("p2p-shard-partition-noncpu"),
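The added digest_metric call emits one count sample per appended partition to whatever context_meter callbacks are active at that point. A sketch of observing such a sample, assuming the add_callback context manager in distributed.metrics behaves as it does elsewhere in the codebase (callbacks receive label, value, unit):

from distributed.metrics import context_meter

seen = []
# add_callback is a context manager; digest_metric calls fired inside the
# block are forwarded to the callback.
with context_meter.add_callback(lambda label, value, unit: seen.append((label, value, unit))):
    context_meter.digest_metric("p2p-partitions", 1, "count")

assert seen == [("p2p-partitions", 1, "count")]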
@@ -372,14 +371,6 @@ def _get_output_partition(
    ) -> _T_partition_type:
        """Get an output partition to the shuffle run"""

-    @abc.abstractmethod
-    def read(self, path: Path) -> tuple[Any, int]:
-        """Read shards from disk"""
-
-    @abc.abstractmethod
-    def deserialize(self, buffer: Any) -> Any:
-        """Deserialize shards"""

def get_worker_plugin() -> ShuffleWorkerPlugin:
from distributed import get_worker