Improve data processing to enable downloading LAION 400M (#19452)
tchaton committed Feb 13, 2024
1 parent 3c5a465 commit b097a4d
Showing 15 changed files with 275 additions and 255 deletions.
2 changes: 1 addition & 1 deletion requirements/data/test.txt
@@ -5,5 +5,5 @@ pytest-timeout ==2.1.0
pytest-rerunfailures ==12.0
pytest-random-order ==1.1.0
viztracer
pandas
pyarrow
polars
7 changes: 7 additions & 0 deletions src/lightning/data/__init__.py
@@ -1,3 +1,5 @@
from lightning_utilities.core.imports import RequirementCache

from lightning.data.processing.functions import map, optimize, walk
from lightning.data.streaming.combined import CombinedStreamingDataset
from lightning.data.streaming.dataloader import StreamingDataLoader
@@ -13,3 +15,8 @@
"optimize",
"walk",
]

if RequirementCache('lightning_sdk'):
from lightning_sdk import Machine # noqa: F401

__all__.append("Machine")
17 changes: 11 additions & 6 deletions src/lightning/data/processing/data_processor.py
@@ -372,7 +372,6 @@ def __init__(
self._counter = 0
self._last_time = time()
self._index_counter = 0
self._current_item: Any = None

def run(self) -> None:
try:
@@ -477,6 +476,7 @@ def _try_upload(self, data: Optional[Union[str, Tuple[str, str]]]) -> None:
assert os.path.exists(data), data
else:
assert os.path.exists(data[-1]), data

self.to_upload_queues[self._counter % self.num_uploaders].put(data)

def _collect_paths(self) -> None:
@@ -588,8 +588,8 @@ def _start_uploaders(self) -> None:

def _handle_data_chunk_recipe(self, index: int) -> None:
try:
self._current_item = self.items[index] if self.reader is None else self.reader.read(self.items[index])
item_data_or_generator = self.data_recipe.prepare_item(self._current_item)
current_item = self.items[index] if self.reader is None else self.reader.read(self.items[index])
item_data_or_generator = self.data_recipe.prepare_item(current_item)
if isinstance(item_data_or_generator, types.GeneratorType):
for item_data in item_data_or_generator:
if item_data is not None:
@@ -713,14 +713,19 @@ def _done(self, size: int, delete_cached_files: bool, output_dir: Dir) -> _Result:
size = sum([c["dim"] if c["dim"] is not None else c["chunk_size"] for c in config["chunks"]])
num_bytes = sum([c["chunk_bytes"] for c in config["chunks"]])
data_format = tree_unflatten(config["config"]["data_format"], treespec_loads(config["config"]["data_spec"]))
num_chunks = len(config["chunks"])

# The platform can't store more than 1024 entries.
# Note: This isn't really used right now, so it is fine to skip if too big.
num_bytes_per_chunk = [c["chunk_size"] for c in config["chunks"]] if num_chunks < 1024 else []

return _Result(
size=size,
num_bytes=num_bytes,
data_format=data_format,
compression=config["config"]["compression"],
num_chunks=len(config["chunks"]),
num_bytes_per_chunk=[c["chunk_size"] for c in config["chunks"]],
num_bytes_per_chunk=num_bytes_per_chunk,
)
return _Result(
size=size,
@@ -866,9 +871,9 @@ def run(self, data_recipe: DataRecipe) -> None:
raise ValueError("The `prepare_structure` should return a list of item metadata.")

if self.reader:
workers_user_items = self.reader.items_to_workers(user_items, self.num_workers)
user_items = self.reader.remap_items(user_items, self.num_workers)

elif self.weights is not None:
if self.weights is not None:
if len(self.weights) != len(user_items):
raise ValueError("The provided weights length should match the inputs' length.")
workers_user_items = _map_items_to_workers_weighted(
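
The `_handle_data_chunk_recipe` hunk above replaces the `self._current_item` attribute with a local variable, so the last processed sample no longer stays referenced between iterations. A simplified, hypothetical illustration of why that matters for memory when samples are large (e.g. LAION images); `load_item`, `process_item`, and `Worker` are stand-ins, not code from this commit:

```python
from typing import Any, List


def load_item(name: str) -> bytearray:
    # Hypothetical stand-in for reading one large sample from disk.
    return bytearray(10 * 1024 * 1024)


def process_item(data: bytearray) -> int:
    # Hypothetical stand-in for `DataRecipe.prepare_item`.
    return len(data)


class Worker:
    def __init__(self) -> None:
        self._current_item: Any = None

    def run_with_attribute(self, items: List[str]) -> None:
        # Binding the sample to `self` keeps the most recent item alive after
        # every iteration (and after the loop), delaying garbage collection.
        for name in items:
            self._current_item = load_item(name)
            process_item(self._current_item)

    def run_with_local(self, items: List[str]) -> None:
        # A local variable is rebound on each iteration and dropped when the
        # method returns, so each sample can be reclaimed promptly.
        for name in items:
            current_item = load_item(name)
            process_item(current_item)
```
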
47 changes: 0 additions & 47 deletions src/lightning/data/processing/dns.py

This file was deleted.

2 changes: 1 addition & 1 deletion src/lightning/data/processing/functions.py
@@ -24,8 +24,8 @@

from lightning.data.constants import _IS_IN_STUDIO, _TORCH_GREATER_EQUAL_2_1_0
from lightning.data.processing.data_processor import DataChunkRecipe, DataProcessor, DataTransformRecipe
from lightning.data.processing.dns import optimize_dns_context
from lightning.data.processing.readers import BaseReader
from lightning.data.processing.utilities import optimize_dns_context
from lightning.data.streaming.resolver import (
Dir,
_assert_dir_has_index_file,
47 changes: 0 additions & 47 deletions src/lightning/data/processing/image.py

This file was deleted.

128 changes: 48 additions & 80 deletions src/lightning/data/processing/readers.py
@@ -1,17 +1,13 @@
import contextlib
import os
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, List, Optional
from typing import Any, List

from lightning_utilities.core.imports import RequirementCache
from tqdm import tqdm

from lightning.data.utilities.env import _DistributedEnv
from lightning.data.utilities.shuffle import _associate_chunks_and_internals_to_ranks

_POLARS_AVAILABLE = RequirementCache("polars")
_PYARROW_AVAILABLE = RequirementCache("pyarrow")


class BaseReader(ABC):

def get_num_nodes(self) -> int:
@@ -21,9 +17,8 @@ def get_node_rank(self) -> int:
return int(os.getenv("DATA_OPTIMIZER_NODE_RANK", 0))

@abstractmethod
def items_to_workers(self, items: List[Any], num_workers: int) -> List[List[Any]]:
"""This method is meant to convert the items provided by the users into items to be processed by the
workers."""
def remap_items(self, items: List[Any], num_workers: int) -> List[Any]:
"""This method is meant to remap the items provided by the users into items more adapted to be distributed."""
pass

@abstractmethod
@@ -32,100 +27,73 @@ def read(self, item: Any) -> Any:
         pass
 
 
-@dataclass
-class ParquetSlice:
-    """Keep track of a parquet file slice with its filepath, start and end."""
-    filepath: str
-    start: int
-    end: int
-
-
 class ParquetReader(BaseReader):
 
-    def __init__(self, num_rows: Optional[int] = 2048, to_pandas: bool = True) -> None:
+    def __init__(self, cache_folder: str, num_rows: int = 65536, to_pandas: bool = True) -> None:
         super().__init__()
+        self.cache_folder = cache_folder
         self.num_rows = num_rows
         self.to_pandas = to_pandas
 
-        if not _PYARROW_AVAILABLE or not _POLARS_AVAILABLE:
-            raise ModuleNotFoundError("Please, run: `pip install pyarrow polars`")
-
-    def _get_num_rows(self, path: str) -> int:
-        if _PYARROW_AVAILABLE:
-            import pyarrow.dataset as ds
-            df = ds.dataset(path).scanner()
-            return df.count_rows()
-
-        # FIXED: There is a bug in polars. This leads to read_parquet to hang.
-        if _POLARS_AVAILABLE:
-            import polars as pol
-            df = pol.scan_parquet(path)
-            num_rows = df.select(pol.len()).collect().item()
-            return num_rows
-
-        raise RuntimeError("Please, install either pyarrow or polars.")
-
-    def read(self, item: ParquetSlice) -> Any:
-        if _POLARS_AVAILABLE:
-            import polars as pol
-            df = pol.scan_parquet(item.filepath).slice(item.start, item.end).collect()
-
-            if self.to_pandas:
-                df = df.to_pandas()
-
-            return df
-
-        if _PYARROW_AVAILABLE:
-            import pyarrow.dataset as ds
-
-            df = ds.dataset(item.filepath).scanner()
-
-            df = df.take([item.start, item.end])
-
-            if self.to_pandas:
-                df.to_pandas()
-
-            return df
-
-        raise RuntimeError("Please, install either pyarrow or polars.")
-
-    def items_to_workers(self, items: Any, num_workers: int) -> List[List[ParquetSlice]]:
-        intervals = [(0, self._get_num_rows(item)) for item in items]
-
-        world_size = self.get_num_nodes() * num_workers
-        node_rank = self.get_node_rank()
-
-        fake_distributed_env = _DistributedEnv(world_size, 0, self.get_num_nodes())
-        parquet_indexes_per_worker, p_slices_per_worker = _associate_chunks_and_internals_to_ranks(
-            fake_distributed_env, list(range(len(items))), intervals, False)
-
-        workers_user_items: List[List[ParquetSlice]] = [[] for _ in range(num_workers)]
-
-        iterator = enumerate(zip(parquet_indexes_per_worker, p_slices_per_worker))
-
-        node_start = node_rank * num_workers
-        node_end = (node_rank + 1) * num_workers
-
-        for worker_idx, (parquet_indexes, p_slices) in iterator:
-            if node_start <= worker_idx < node_end:
-                if self.num_rows:
-                    workers_user_items[worker_idx % num_workers].extend([
-                        ParquetSlice(
-                            items[parquet_index], p_slice_start, p_slice_start + self.num_rows
-                            if p_slice[1] > (p_slice_start + self.num_rows) else
-                            p_slice[1]
-                        )
-                        for parquet_index, p_slice in zip(parquet_indexes, p_slices)
-                        for p_slice_start in range(p_slice[0], p_slice[1] + self.num_rows, self.num_rows)
-                        if p_slice_start < p_slice[1]
-                    ])
-                else:
-                    workers_user_items[worker_idx % num_workers].extend([
-                        ParquetSlice(items[parquet_index], *p_slice)
-                        for parquet_index, p_slice in zip(parquet_indexes, p_slices)
-                    ])
-
-        assert len(workers_user_items) == num_workers
-        assert all(len(w) for w in workers_user_items)
-
-        return workers_user_items
+        if not _PYARROW_AVAILABLE:
+            raise ModuleNotFoundError("Please, run: `pip install pyarrow`")
+
+        self.parquet_file = None
+
+    def _get_num_rows(self, path: str) -> int:
+        import pyarrow.dataset as ds
+
+        df = ds.dataset(path).scanner()
+        return df.count_rows()
+
+    def read(self, filepath: str) -> Any:
+        import pyarrow as pa
+        import pyarrow.parquet as pq
+
+        # Try to force dellocation to avoid memory leak
+        with contextlib.suppress(Exception):
+            pa.jemalloc_set_decay_ms(0)
+
+        # close the previous parquet file to release the memory
+        if self.parquet_file is not None:
+            self.parquet_file.close()
+            self.parquet_file = None
+
+        self.parquet_file = pq.ParquetFile(filepath, memory_map=True)
+        return self.parquet_file
+
+    def remap_items(self, filepaths: List[str], _: int) -> List[str]:
+        import pyarrow.parquet as pq
+
+        print("Starting resharding the parquet files for optimized processing.")
+
+        new_items = []
+
+        cache_folder = os.path.join(self.cache_folder, f"{self.num_rows}")
+        os.makedirs(cache_folder, exist_ok=True)
+
+        for filepath in filepaths:
+            num_rows = self._get_num_rows(filepath)
+
+            table = None
+            parquet_filename = os.path.basename(filepath)
+
+            for start in tqdm(range(0, num_rows, self.num_rows)):
+                end = min(start + self.num_rows, num_rows)
+                chunk_filepath = os.path.join(cache_folder, f"{start}_{end}_{parquet_filename}")
+                new_items.append(chunk_filepath)
+
+                if os.path.exists(chunk_filepath):
+                    continue
+
+                if table is None:
+                    table = pq.read_table(filepath, memory_map=True)
+
+                pq.write_table(table[start: end], chunk_filepath)
+
+        print("Finished resharding the parquet files for optimized processing.")
+
+        return new_items
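
Taken together, the new `remap_items` reshards each input parquet file into fixed-size row chunks cached on disk, and `read` then memory-maps one chunk at a time. A condensed sketch of the resharding idea, assuming `pyarrow` is installed; `reshard_parquet` is an illustrative helper, not a function from the commit:

```python
import os
from typing import List

import pyarrow.parquet as pq


def reshard_parquet(filepath: str, cache_folder: str, rows_per_chunk: int = 65536) -> List[str]:
    """Split one parquet file into smaller files of at most `rows_per_chunk` rows."""
    os.makedirs(cache_folder, exist_ok=True)

    # Memory-map the source file so slicing below does not copy the whole table.
    table = pq.read_table(filepath, memory_map=True)
    filename = os.path.basename(filepath)

    chunk_paths = []
    for start in range(0, table.num_rows, rows_per_chunk):
        end = min(start + rows_per_chunk, table.num_rows)
        chunk_path = os.path.join(cache_folder, f"{start}_{end}_{filename}")
        chunk_paths.append(chunk_path)

        # Chunks written by a previous run are reused as-is.
        if not os.path.exists(chunk_path):
            pq.write_table(table.slice(start, end - start), chunk_path)

    return chunk_paths
```

Workers then receive plain chunk paths, and the reader opens each one lazily with `pq.ParquetFile(chunk_path, memory_map=True)`, closing the previous file first so only one chunk stays mapped per worker.
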
