Skip to content

Commit

Permalink
Add thread_map to arrow_reader._read_files huggingface#2252
Browse files Browse the repository at this point in the history
  • Loading branch information
kkoutini committed Dec 6, 2023
1 parent 555c8a3 commit 15e6206
Showing 1 changed file with 11 additions and 0 deletions.
11 changes: 11 additions & 0 deletions src/datasets/arrow_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
""" Arrow ArrowReader."""

import copy
from functools import partial
import math
import os
import re
Expand All @@ -26,11 +27,13 @@

import pyarrow as pa
import pyarrow.parquet as pq
from tqdm.contrib.concurrent import thread_map

from .download.download_config import DownloadConfig
from .naming import _split_re, filenames_for_dataset_split
from .table import InMemoryTable, MemoryMappedTable, Table, concat_tables
from .utils import logging
from .utils import tqdm as hf_tqdm
from .utils.file_utils import cached_path


Expand Down Expand Up @@ -196,6 +199,14 @@ def _read_files(self, files, in_memory=False) -> Table:
files = copy.deepcopy(files)
for f in files:
f["filename"] = os.path.join(self._path, f["filename"])

pa_tables = thread_map(
partial(self._get_table_from_filename, in_memory=in_memory),
files,
tqdm_class=hf_tqdm,
desc="Loading dataset shards",
disable=len(files) <= 16,
)
for f_dict in files:
pa_table: Table = self._get_table_from_filename(f_dict, in_memory=in_memory)
pa_tables.append(pa_table)
Expand Down

0 comments on commit 15e6206

Please sign in to comment.