diff --git a/src/datasets/arrow_dataset.py b/src/datasets/arrow_dataset.py
index 92267650cbb..42479e42091 100644
--- a/src/datasets/arrow_dataset.py
+++ b/src/datasets/arrow_dataset.py
@@ -1854,17 +1854,64 @@ def __len__(self):
         """
         return self.num_rows
 
+    def _iter_batches(self, batch_size: int, decoded: bool = True):
+        """Iterate through the batches of size `batch_size`.
+
+        If a formatting is set with :meth:`Dataset.set_format` rows will be returned with the
+        selected format.
+        """
+        if self._indices is None and config.PYARROW_VERSION.major >= 8:
+            # Fast iteration
+            # Benchmark: https://gist.github.com/mariosasko/0248288a2e3a7556873969717c1fe52b (fast_iter_batch)
+            format_kwargs = self._format_kwargs if self._format_kwargs is not None else {}
+            formatter = get_formatter(self._format_type, features=self.features, decoded=decoded, **format_kwargs)
+            for batch in self.data.to_reader(max_chunksize=batch_size):
+                pa_subtable = pa.Table.from_batches([batch])
+                formatted_output = format_table(
+                    pa_subtable,
+                    range(pa_subtable.num_rows),
+                    formatter=formatter,
+                    format_columns=self._format_columns,
+                    output_all_columns=self._output_all_columns,
+                )
+                yield formatted_output
+        else:
+            for i in range(0, self.num_rows, batch_size):
+                yield self._getitem(
+                    slice(i, i + batch_size),
+                    decoded=decoded,
+                )
+
     def _iter(self, decoded: bool = True):
         """Iterate through the examples.
 
         If a formatting is set with :meth:`Dataset.set_format` rows will be returned with the
         selected format.
         """
-        for index in range(self.num_rows):
-            yield self._getitem(
-                index,
-                decoded=decoded,
-            )
+        if self._indices is None and config.PYARROW_VERSION.major >= 8:
+            # Fast iteration
+            # Benchmark: https://gist.github.com/mariosasko/0248288a2e3a7556873969717c1fe52b (fast_iter_batch)
+            format_kwargs = self._format_kwargs if self._format_kwargs is not None else {}
+            formatter = get_formatter(self._format_type, features=self.features, decoded=decoded, **format_kwargs)
+            batch_size = config.ARROW_READER_BATCH_SIZE_IN_DATASET_ITER
+            for batch in self.data.to_reader(max_chunksize=batch_size):
+                for i in range(batch.num_rows):
+                    batch_ex = batch.slice(i, 1)
+                    pa_subtable = pa.Table.from_batches([batch_ex])
+                    formatted_output = format_table(
+                        pa_subtable,
+                        0,
+                        formatter=formatter,
+                        format_columns=self._format_columns,
+                        output_all_columns=self._output_all_columns,
+                    )
+                    yield formatted_output
+        else:
+            for i in range(self.num_rows):
+                yield self._getitem(
+                    i,
+                    decoded=decoded,
+                )
 
     def __iter__(self):
         """Iterate through the examples.
@@ -2805,14 +2852,16 @@ def init_buffer_and_writer():
 
                 # Loop over single examples or batches and write to buffer/file if examples are to be updated
                 if not batched:
-                    pbar_iterable = input_dataset._iter(decoded=False)
                     pbar_total = len(input_dataset)
+                    pbar_iterable = input_dataset._iter(decoded=False)
                 else:
                     num_rows = (
                         len(input_dataset) if not drop_last_batch else len(input_dataset) // batch_size * batch_size
                     )
-                    pbar_iterable = range(0, num_rows, batch_size)
                     pbar_total = (num_rows // batch_size) + 1 if num_rows % batch_size else num_rows // batch_size
+                    pbar_iterable = itertools.islice(
+                        input_dataset._iter_batches(batch_size, decoded=False), pbar_total
+                    )
                 pbar_unit = "ex" if not batched else "ba"
                 pbar_desc = (desc + " " if desc is not None else "") + "#" + str(rank) if rank is not None else desc
                 pbar = logging.tqdm(
@@ -2835,11 +2884,7 @@ def init_buffer_and_writer():
                             else:
                                 writer.write(example)
                 else:
-                    for i in pbar:
-                        batch = input_dataset._getitem(
-                            slice(i, i + batch_size),
-                            decoded=False,
-                        )
+                    for i, batch in enumerate(pbar):
                         indices = list(
-                            range(*(slice(i, i + batch_size).indices(input_dataset.num_rows)))
+                            range(*(slice(i * batch_size, (i + 1) * batch_size).indices(input_dataset.num_rows)))
                         )  # Something simpler?
diff --git a/src/datasets/config.py b/src/datasets/config.py
index 2bd5419cbe3..8595d5fc581 100644
--- a/src/datasets/config.py
+++ b/src/datasets/config.py
@@ -168,6 +168,9 @@
 # https://github.com/apache/arrow/blob/master/docs/source/cpp/arrays.rst#size-limitations-and-recommendations)
 DEFAULT_MAX_BATCH_SIZE = 10_000
 
+# Size of the preloaded record batch in `Dataset.__iter__`
+ARROW_READER_BATCH_SIZE_IN_DATASET_ITER = 10
+
 # Pickling tables works only for small tables (<4GiB)
 # For big tables, we write them on disk instead
 MAX_TABLE_NBYTES_FOR_PICKLING = 4 << 30
diff --git a/src/datasets/table.py b/src/datasets/table.py
index 6ea551626b4..ed56c195172 100644
--- a/src/datasets/table.py
+++ b/src/datasets/table.py
@@ -330,6 +330,30 @@ def to_pandas(self, *args, **kwargs):
     def to_string(self, *args, **kwargs):
         return self.table.to_string(*args, **kwargs)
 
+    def to_reader(self, *args, **kwargs):
+        """
+        Convert the Table to a RecordBatchReader.
+
+        Note that this method is zero-copy: it merely exposes the same data under a different API.
+
+        Args:
+            max_chunksize (:obj:`int`, defaults to :obj:`None`):
+                Maximum size for RecordBatch chunks. Individual chunks may be smaller depending
+                on the chunk layout of individual columns.
+
+        Returns:
+            :obj:`pyarrow.RecordBatchReader`
+
+        <Tip warning={true}>
+
+        pyarrow >= 8.0.0 needs to be installed to use this method.
+
+        </Tip>
+        """
+        if config.PYARROW_VERSION.major < 8:
+            raise NotImplementedError("`pyarrow>=8.0.0` is required to use this method")
+        return self.table.to_reader(*args, **kwargs)
+
     def field(self, *args, **kwargs):
         """
         Select a schema field by its column name or numeric index.
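A note on the fast path in `_iter_batches` and `_iter` above: it relies on `pyarrow.Table.to_reader`, which streams a table as `RecordBatch` chunks without copying. The sketch below shows the core idea outside of the `datasets` internals; the toy table, column names, and chunk size are illustrative only, and the real code additionally routes each sub-table through `format_table`.

```python
import pyarrow as pa

# Toy stand-in for Dataset.data (values are made up for illustration).
table = pa.table({"id": list(range(10)), "text": [f"row {i}" for i in range(10)]})

# Batch iteration, as in `_iter_batches`: stream RecordBatch chunks zero-copy
# (requires pyarrow >= 8.0.0), then wrap each chunk back into a Table so the
# usual formatting code can consume it.
for batch in table.to_reader(max_chunksize=4):
    pa_subtable = pa.Table.from_batches([batch])
    print(pa_subtable.num_rows, pa_subtable.to_pydict()["id"])  # chunks of 4, 4, 2 rows

# Example iteration, as in the updated `_iter`: slice single rows out of each
# preloaded batch instead of issuing one `_getitem` call per index.
for batch in table.to_reader(max_chunksize=4):
    for i in range(batch.num_rows):
        example = pa.Table.from_batches([batch.slice(i, 1)])
        assert example.num_rows == 1
```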
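On the `map` loop change: `pbar` now yields batches rather than row offsets, so the absolute row indices have to be recomputed from the batch counter, which is why the slice bounds are scaled by `batch_size`. A minimal sketch of that bookkeeping with made-up numbers; the `iter_batches` helper below is hypothetical, standing in for `Dataset._iter_batches`.

```python
import itertools

num_rows, batch_size = 10, 4

# Mirrors the pbar_total computation in the diff: ceil(num_rows / batch_size).
pbar_total = (num_rows // batch_size) + 1 if num_rows % batch_size else num_rows // batch_size


def iter_batches(num_rows, batch_size):
    """Hypothetical stand-in for Dataset._iter_batches: yields lists of row ids."""
    for start in range(0, num_rows, batch_size):
        yield list(range(start, min(start + batch_size, num_rows)))


# `i` is the batch counter from enumerate, not a row offset, so the slice
# bounds are scaled by batch_size to recover the absolute row indices.
for i, batch in enumerate(itertools.islice(iter_batches(num_rows, batch_size), pbar_total)):
    indices = list(range(*(slice(i * batch_size, (i + 1) * batch_size).indices(num_rows))))
    assert indices == batch  # [0, 1, 2, 3], then [4, 5, 6, 7], then [8, 9]
```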
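For the `Table.to_reader` wrapper in `table.py`, the version guard matters because `pyarrow.Table.to_reader` only exists from pyarrow 8.0.0 onwards. A rough sketch of the same guard written as a free function, assuming `packaging` is available (as it is for `datasets.config.PYARROW_VERSION`); the wrapper name is illustrative, not part of the library.

```python
import pyarrow as pa
from packaging import version

PYARROW_VERSION = version.parse(pa.__version__)


def to_reader_compat(table: pa.Table, max_chunksize=None) -> pa.RecordBatchReader:
    """Illustrative version-guarded wrapper, not the library's public API."""
    if PYARROW_VERSION.major < 8:
        raise NotImplementedError("`pyarrow>=8.0.0` is required to use this method")
    # Zero-copy: exposes the table's record batches through a reader.
    return table.to_reader(max_chunksize=max_chunksize)


reader = to_reader_compat(pa.table({"a": [1, 2, 3]}), max_chunksize=2)
print([batch.num_rows for batch in reader])  # [2, 1]
```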