Commit

add num_process to load_from_disk huggingface#2252
kkoutini committed Dec 1, 2023
1 parent 2d31e43 commit 9bcd119
Showing 1 changed file with 31 additions and 4 deletions.
35 changes: 31 additions & 4 deletions src/datasets/arrow_dataset.py
@@ -15,6 +15,7 @@
# Lint as: python3
""" Simple Dataset wrapping an Arrow Table."""

import concurrent.futures
import contextlib
import copy
import fnmatch
@@ -1595,6 +1596,7 @@ def load_from_disk(
dataset_path: str,
fs="deprecated",
keep_in_memory: Optional[bool] = None,
num_proc: Optional[int] = None,
storage_options: Optional[dict] = None,
) -> "Dataset":
"""
@@ -1620,6 +1622,9 @@
dataset will not be copied in-memory unless explicitly enabled by setting
`datasets.config.IN_MEMORY_MAX_SIZE` to nonzero. See more details in the
[improve performance](../cache#improve-performance) section.
num_proc (`int`, *optional*):
Number of processes to use when loading the dataset's Arrow shards from disk.
Multiprocessing is disabled by default.
storage_options (`dict`, *optional*):
Key/value pairs to be passed on to the file-system backend, if any.
@@ -1698,10 +1703,32 @@
)
keep_in_memory = keep_in_memory if keep_in_memory is not None else is_small_dataset(dataset_size)
table_cls = InMemoryTable if keep_in_memory else MemoryMappedTable
arrow_table = concat_tables(
table_cls.from_file(posixpath.join(dest_dataset_path, data_file["filename"]))
for data_file in state["_data_files"]
)
num_proc = num_proc if num_proc is not None else 1
if num_proc > 1:
pbar = hf_tqdm(
unit="shards",
total=len(state["_data_files"]),
desc="Loading the dataset from disk",
)
tables = []
# Use threads rather than processes for faster collection of the loaded tables
with concurrent.futures.ThreadPoolExecutor(max_workers=num_proc) as pool:
with pbar:
for table in pool.map(
table_cls.from_file,
[
posixpath.join(dest_dataset_path, data_file["filename"])
for data_file in state["_data_files"]
],
):
tables.append(table)
pbar.update(1)
arrow_table = concat_tables(tables)
else:
arrow_table = concat_tables(
table_cls.from_file(posixpath.join(dest_dataset_path, data_file["filename"]))
for data_file in state["_data_files"]
)

split = state["_split"]
split = Split(split) if split is not None else split
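A minimal usage sketch of the new argument (the path and worker count below are placeholders, not part of the commit): `Dataset.load_from_disk` reads the shards previously written by `save_to_disk`, and with `num_proc > 1` it opens them in a thread pool instead of one after another.

from datasets import Dataset

# Load a dataset previously written with save_to_disk; num_proc=4 opens the
# Arrow shard files with 4 worker threads instead of sequentially.
ds = Dataset.load_from_disk("path/to/saved_dataset", num_proc=4)
print(ds)

The diff favours a thread pool over a process pool, presumably because the loaded (memory-mapped or in-memory) Arrow tables can then be collected directly in the parent interpreter without inter-process serialization.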
