Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[POC] Distributed checkpointing API for dask collections #8483

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
247 changes: 247 additions & 0 deletions distributed/checkpoint.py
@@ -0,0 +1,247 @@
from __future__ import annotations

import contextlib
import glob
import os
import pickle
from collections import defaultdict
from importlib import import_module

import dask.dataframe as dd
from dask.blockwise import BlockIndex
from dask.utils import typename

from distributed import default_client, wait
from distributed.protocol import dask_deserialize, dask_serialize


class Handler:
    """Base class for format-specific checkpointing handlers.

    Each ``Handler`` instance is responsible for exactly one saved
    partition of a checkpointed collection.
    """

    # Label for the on-disk format; also used as the file suffix.
    fmt: None | str = None

    def __init__(self, path, backend, index, **kwargs):
        # Directory containing the checkpointed files.
        self.path = path
        # Backend module name (e.g. "pandas") used to read the data back.
        self.backend = backend
        # Partition index this handler is responsible for.
        self.index = index
        # Extra keyword arguments forwarded to the load routine.
        self.kwargs = kwargs

    @classmethod
    def clean(cls, dirpath):
        """Remove the target directory if it exists (best effort)."""
        import shutil

        if not os.path.isdir(dirpath):
            return
        with contextlib.suppress(FileNotFoundError):
            shutil.rmtree(dirpath)

    @classmethod
    def prepare(cls, dirpath):
        """Ensure the target directory exists."""
        os.makedirs(dirpath, exist_ok=True)

    @classmethod
    def save(cls, part, path, index):
        """Persist the target partition to disk."""
        raise NotImplementedError()  # Subclasses implement per-format logic

    @classmethod
    def get_indices(cls, path):
        """Return the set of partition indices stored under ``path``."""
        # File names are assumed to follow: <name>.<index>.<fmt>
        found = set()
        for filename in glob.glob(path + f"/*.{cls.fmt}"):
            found.add(int(filename.split(".")[-2]))
        return found

    def load(self):
        """Read the saved partition back into memory."""
        raise NotImplementedError()  # Subclasses implement per-format logic

@dask_serialize.register(Handler)
def _serialize_unloaded(obj):
    """Serialize a ``Handler`` by materializing the partition it wraps.

    Guarantees the partition data travels with the object whenever it
    is moved to a different worker.
    """
    loaded = obj.load()
    return None, [pickle.dumps(loaded)]

@dask_deserialize.register(Handler)
def _deserialize_unloaded(header, frames):
    """Deserialize a ``Handler`` into the wrapped partition data itself."""
    payload = frames[0]
    return pickle.loads(payload)

class ParquetHandler(Handler):
    """Parquet-specific checkpointing handler for DataFrame collections"""

    fmt = "parquet"

    @classmethod
    def save(cls, part, path, index, **kwargs):
        """Write one partition to ``<path>/part.<index>.parquet``.

        ``index`` is a ``BlockIndex`` tuple; ``index[0]`` is the
        partition number. ``kwargs`` are forwarded to ``to_parquet`` —
        without them, the ``**save_kwargs`` that
        ``DataFrameCheckpoint.create`` passes through
        ``map_partitions`` would raise ``TypeError``.
        """
        fn = f"{path}/part.{index[0]}.parquet"
        part.to_parquet(fn, **kwargs)
        return index

    def load(self):
        """Read this handler's partition back using the original backend."""
        lib = import_module(self.backend)
        # NOTE(review): glob returns a *list* of matches; with the
        # "<name>.<index>.parquet" naming there should be exactly one
        # file per index on this worker — confirm backends accept a list.
        fn = glob.glob(f"{self.path}/*.{self.index}.parquet")
        return lib.read_parquet(fn, **self.kwargs)

class BaseCheckpoint:
    """Checkpoint a Dask collection on disk.

    The storage location does not need to be shared between workers.
    """

    @classmethod
    def create(cls, *args, **kwargs):
        """Create a new Checkpoint object."""
        raise NotImplementedError()

    def load(self):
        """Load a checkpointed collection.

        Note that this will not immediately persist the partitions in
        memory; it returns a lazy Dask collection instead.
        """
        raise NotImplementedError()

    def clean(self):
        """Clean up this checkpoint."""
        raise NotImplementedError()

class DataFrameCheckpoint(BaseCheckpoint):
    """Checkpoint a Dask DataFrame on disk"""

    def __init__(
        self,
        npartitions,
        meta,
        handler,
        path,
        load_kwargs,
    ):
        # Expected number of saved partitions.
        self.npartitions = npartitions
        # Empty DataFrame describing the schema of the collection.
        self.meta = meta
        # Backend module name, e.g. "pandas" out of "pandas.core.frame.DataFrame".
        self.backend = typename(meta).partition(".")[0]
        # Handler class implementing the on-disk format.
        self.handler = handler
        # Directory (per worker) where partitions are stored.
        self.path = path
        self.load_kwargs = load_kwargs or {}
        # Flipped to False once `clean` removes the on-disk data.
        self._valid = True

    def __repr__(self):
        path = self.path
        fmt = self.handler.fmt
        return f"DataFrameCheckpoint<path={path}, format={fmt}>"

    @classmethod
    def create(
        cls,
        df,
        path,
        format="parquet",
        overwrite=True,
        compute_kwargs=None,
        load_kwargs=None,
        **save_kwargs,
    ):
        """Persist every partition of ``df`` under ``path`` on its worker.

        Parameters
        ----------
        df : dask.dataframe.DataFrame
            Collection to checkpoint.
        path : str
            Directory used on each worker (need not be shared storage).
        format : str, default "parquet"
            On-disk format; only "parquet" is supported for now.
        overwrite : bool, default True
            Remove any pre-existing data at ``path`` first.
        compute_kwargs, load_kwargs : dict, optional
            Forwarded to ``compute`` / the handler's ``load``.
        **save_kwargs
            Forwarded to the handler's ``save``.
        """
        # Pick the format-specific handler
        if format == "parquet":
            handler = ParquetHandler
        else:
            # Only parquet supported for now
            raise NotImplementedError()

        client = default_client()

        if overwrite:
            wait(client.run(handler.clean, path))
        # Always make sure the target directory exists (idempotent)
        wait(client.run(handler.prepare, path))

        meta = df._meta.copy()
        df.map_partitions(
            handler.save,
            path,
            BlockIndex((df.npartitions,)),
            meta=meta,
            enforce_metadata=False,
            **save_kwargs,
        ).compute(**(compute_kwargs or {}))

        return cls(
            df.npartitions,
            meta,
            handler,
            path,
            load_kwargs,
        )

    def load(self):
        """Reconstruct a lazy Dask DataFrame from the saved partitions."""
        if not self._valid:
            raise RuntimeError("This checkpoint is no longer valid")

        #
        # Get client and check workers
        #
        client = default_client()

        #
        # Find out which partition indices are stored on each worker
        #
        worker_indices = client.run(self.handler.get_indices, self.path)
        summary = defaultdict(list)
        for worker, indices in worker_indices.items():
            for index in indices:
                summary[index].append(worker)

        # Check partition count
        npartitions_found = len(summary)
        if npartitions_found != self.npartitions:
            raise RuntimeError(
                f"Expected {self.npartitions} partitions. "
                f"Found {npartitions_found}."
            )

        #
        # Convert each checkpointed partition to a `Handler` object,
        # pinned to a worker that actually holds that partition.
        #
        # BUGFIX: the original loop unpacked `summary.items()` as
        # `(worker, indices)` although `summary` maps partition index
        # -> worker list, then looked up `assignments[i]` by the loop
        # counter (KeyError for unsorted dict order) and submitted the
        # counter as the partition index — so a task could be pinned to
        # a worker that does not hold its partition.
        #
        futures = []
        for i, index in enumerate(sorted(summary)):
            holders = summary[index]
            # Round-robin among the workers holding a replica
            worker = holders[i % len(holders)]
            futures.append(
                client.submit(
                    self.handler,
                    self.path,
                    self.backend,
                    index,
                    workers=[worker],
                    **self.load_kwargs,
                )
            )
        wait(futures)

        #
        # Create a new collection from the delayed `Handler` objects
        #
        meta = self.meta
        return dd.from_delayed(futures, meta=meta, verify_meta=False).map_partitions(
            self._load_partition,
            meta=meta,
        )

    @staticmethod
    def _load_partition(obj):
        # Load a checkpointed partition.
        # Used by DataFrameCheckpoint.load
        # A partition that was moved between workers arrives already
        # deserialized into raw data; pass it through unchanged.
        if isinstance(obj, Handler):
            return obj.load()
        return obj

    def clean(self):
        """Remove the on-disk data and invalidate this checkpoint."""
        client = default_client()
        wait(client.run(self.handler.clean, self.path))
        self._valid = False
14 changes: 14 additions & 0 deletions distributed/tests/test_dask_collections.py
Expand Up @@ -252,3 +252,17 @@ def test_tuple_futures_arg(client, typ):
),
)
dd.assert_eq(df2.result().iloc[:0], make_time_dataframe().iloc[:0])


def test_dataframe_checkpoint(client, tmp_path):
    from distributed.checkpoint import DataFrameCheckpoint

    pdf = make_time_dataframe()
    ddf = dd.from_pandas(pdf, npartitions=10)

    checkpoint = DataFrameCheckpoint.create(ddf, str(tmp_path))

    # Drop the original collection so `load` must read the disk copy
    client.cancel(ddf)
    del ddf

    # Must use distributed scheduler to compute
    dd.assert_eq(pdf, checkpoint.load(), scheduler=None)