Add concurrent loading of shards to datasets.load_from_disk (#6464)
* add threadmap to load_from_disk #2252

* Add threadmap to arrow_reader.read_files #2252

* remove old way of loading files

* sort imports

---------

Co-authored-by: Quentin Lhoest <42851186+lhoestq@users.noreply.github.com>
kkoutini and lhoestq committed Jan 26, 2024
1 parent adfe8f8 commit 65434e4
Showing 2 changed files with 20 additions and 6 deletions.
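Background for the two patches below: both swap a sequential loop over Arrow shard files for `tqdm.contrib.concurrent.thread_map`, which maps a function over a list with a thread pool and an optional progress bar; opening and memory-mapping shards is I/O-bound, so threads help despite the GIL. A minimal standalone sketch of that pattern (the `open_shard` helper and the shard paths are illustrative, not library code):

```python
from tqdm.contrib.concurrent import thread_map


def open_shard(path: str) -> str:
    # Stand-in for table_cls.from_file(path), which memory-maps one Arrow shard.
    return path


# Illustrative shard names, not real files.
paths = [f"data-{i:05d}-of-00064.arrow" for i in range(64)]

# Same call shape as the patches: the progress bar is shown only when there
# are more than 16 files, so small datasets stay quiet.
tables = thread_map(open_shard, paths, desc="Loading dataset from disk", disable=len(paths) <= 16)
```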
11 changes: 9 additions & 2 deletions src/datasets/arrow_dataset.py
@@ -61,6 +61,7 @@
 import pyarrow.compute as pc
 from huggingface_hub import CommitInfo, CommitOperationAdd, CommitOperationDelete, DatasetCard, DatasetCardData, HfApi
 from multiprocess import Pool
+from tqdm.contrib.concurrent import thread_map
 
 from . import config
 from .arrow_reader import ArrowReader
@@ -1703,9 +1704,15 @@ def load_from_disk(
         )
         keep_in_memory = keep_in_memory if keep_in_memory is not None else is_small_dataset(dataset_size)
         table_cls = InMemoryTable if keep_in_memory else MemoryMappedTable
+
         arrow_table = concat_tables(
-            table_cls.from_file(posixpath.join(dest_dataset_path, data_file["filename"]))
-            for data_file in state["_data_files"]
+            thread_map(
+                table_cls.from_file,
+                [posixpath.join(dest_dataset_path, data_file["filename"]) for data_file in state["_data_files"]],
+                tqdm_class=hf_tqdm,
+                desc="Loading dataset from disk",
+                disable=len(state["_data_files"]) <= 16,
+            )
         )
 
         split = state["_split"]
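A hedged usage sketch of the effect of the `load_from_disk` change (the dataset contents, path, and shard count are made up; `num_shards` is an existing `save_to_disk` argument in recent `datasets` releases):

```python
from datasets import Dataset, load_from_disk

# Illustrative: save a dataset split into many Arrow shards, then reload it.
ds = Dataset.from_dict({"text": [f"example {i}" for i in range(10_000)]})
ds.save_to_disk("my_dataset", num_shards=64)

# With this commit the 64 shards are opened concurrently via thread_map,
# and a "Loading dataset from disk" progress bar appears because the
# dataset has more than 16 data files.
reloaded = load_from_disk("my_dataset")
print(reloaded)
```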
15 changes: 11 additions & 4 deletions src/datasets/arrow_reader.py
@@ -21,16 +21,19 @@
 import re
 import shutil
 from dataclasses import dataclass
+from functools import partial
 from pathlib import Path
 from typing import TYPE_CHECKING, List, Optional, Union
 
 import pyarrow as pa
 import pyarrow.parquet as pq
+from tqdm.contrib.concurrent import thread_map
 
 from .download.download_config import DownloadConfig
 from .naming import _split_re, filenames_for_dataset_split
 from .table import InMemoryTable, MemoryMappedTable, Table, concat_tables
 from .utils import logging
+from .utils import tqdm as hf_tqdm
 from .utils.file_utils import cached_path


@@ -192,13 +195,17 @@ def _read_files(self, files, in_memory=False) -> Table:
         """
         if len(files) == 0 or not all(isinstance(f, dict) for f in files):
             raise ValueError("please provide valid file informations")
-        pa_tables = []
         files = copy.deepcopy(files)
         for f in files:
             f["filename"] = os.path.join(self._path, f["filename"])
-        for f_dict in files:
-            pa_table: Table = self._get_table_from_filename(f_dict, in_memory=in_memory)
-            pa_tables.append(pa_table)
+
+        pa_tables = thread_map(
+            partial(self._get_table_from_filename, in_memory=in_memory),
+            files,
+            tqdm_class=hf_tqdm,
+            desc="Loading dataset shards",
+            disable=len(files) <= 16,
+        )
         pa_tables = [t for t in pa_tables if len(t) > 0]
         if not pa_tables and (self._info is None or self._info.features is None):
             raise ValueError(
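In `_read_files`, `functools.partial` pre-binds the `in_memory` keyword so that `thread_map` only has to map the bound callable over the list of file dicts. A self-contained sketch of that combination, with a dummy loader standing in for `BaseReader._get_table_from_filename`:

```python
from functools import partial

from tqdm.contrib.concurrent import thread_map


def get_table_from_filename(file_info: dict, in_memory: bool = False) -> dict:
    # Dummy stand-in: a real reader would open the Arrow file named here.
    return {"filename": file_info["filename"], "in_memory": in_memory}


# Illustrative file dicts mirroring the shape the reader receives.
files = [{"filename": f"shard-{i:05d}.arrow"} for i in range(32)]

# partial() fixes in_memory, so thread_map passes each file dict as the
# single positional argument; the bar is hidden for 16 files or fewer.
tables = thread_map(
    partial(get_table_from_filename, in_memory=False),
    files,
    desc="Loading dataset shards",
    disable=len(files) <= 16,
)
```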
