diff --git a/src/datasets/arrow_dataset.py b/src/datasets/arrow_dataset.py
index d3db912357b..cb32908515c 100644
--- a/src/datasets/arrow_dataset.py
+++ b/src/datasets/arrow_dataset.py
@@ -97,6 +97,7 @@
     embed_table_storage,
     list_table_cache_files,
     table_cast,
+    table_iter,
     table_visitor,
 )
 from .tasks import TaskTemplate
@@ -1924,7 +1925,7 @@ def __len__(self):
         """
         return self.num_rows
 
-    def _iter_batches(self, batch_size: int, decoded: bool = True):
+    def _iter_batches(self, batch_size: int, decoded: bool = True, drop_last_batch: bool = False):
         """Iterate through the batches of size `batch_size`.
 
         If a formatting is set with :meth:`Dataset.set_format` rows will be returned with the
@@ -1935,16 +1936,15 @@ def _iter_batches(self, batch_size: int, decoded: bool = True):
             # Benchmark: https://gist.github.com/mariosasko/0248288a2e3a7556873969717c1fe52b (fast_iter_batch)
             format_kwargs = self._format_kwargs if self._format_kwargs is not None else {}
             formatter = get_formatter(self._format_type, features=self.features, decoded=decoded, **format_kwargs)
-            for batch in self.data.to_reader(max_chunksize=batch_size):
-                pa_subtable = pa.Table.from_batches([batch])
-                formatted_output = format_table(
+            for pa_subtable in table_iter(self.data, batch_size=batch_size, drop_last_batch=drop_last_batch):
+                formatted_batch = format_table(
                     pa_subtable,
                     range(pa_subtable.num_rows),
                     formatter=formatter,
                     format_columns=self._format_columns,
                     output_all_columns=self._output_all_columns,
                 )
-                yield formatted_output
+                yield formatted_batch
         else:
             for i in range(0, self.num_rows, batch_size):
                 yield self._getitem(
@@ -1964,12 +1964,11 @@ def _iter(self, decoded: bool = True):
             format_kwargs = self._format_kwargs if self._format_kwargs is not None else {}
             formatter = get_formatter(self._format_type, features=self.features, decoded=decoded, **format_kwargs)
             batch_size = config.ARROW_READER_BATCH_SIZE_IN_DATASET_ITER
-            for batch in self.data.to_reader(max_chunksize=batch_size):
-                for i in range(batch.num_rows):
-                    batch_ex = batch.slice(i, 1)
-                    pa_subtable = pa.Table.from_batches([batch_ex])
+            for pa_subtable in table_iter(self.data, batch_size=batch_size):
+                for i in range(pa_subtable.num_rows):
+                    pa_subtable_ex = pa_subtable.slice(i, 1)
                     formatted_output = format_table(
-                        pa_subtable,
+                        pa_subtable_ex,
                         0,
                         formatter=formatter,
                         format_columns=self._format_columns,
@@ -2936,8 +2935,8 @@ def init_buffer_and_writer():
                     len(input_dataset) if not drop_last_batch else len(input_dataset) // batch_size * batch_size
                 )
                 pbar_total = (num_rows // batch_size) + 1 if num_rows % batch_size else num_rows // batch_size
-                pbar_iterable = itertools.islice(
-                    input_dataset._iter_batches(batch_size, decoded=False), pbar_total
+                pbar_iterable = input_dataset._iter_batches(
+                    batch_size, decoded=False, drop_last_batch=drop_last_batch
                 )
                 pbar_unit = "ex" if not batched else "ba"
                 pbar_desc = (desc + " " if desc is not None else "") + "#" + str(rank) if rank is not None else desc
diff --git a/src/datasets/table.py b/src/datasets/table.py
index ed56c195172..f585d72efb0 100644
--- a/src/datasets/table.py
+++ b/src/datasets/table.py
@@ -2148,3 +2148,41 @@ def _visit(array, feature):
 
     for name, feature in features.items():
         _visit(table[name], feature)
+
+
+def table_iter(pa_table: pa.Table, batch_size: int, drop_last_batch: bool = False):
+    """Iterate over sub-tables of size `batch_size`.
+
+    Requires pyarrow>=8.0.0
+
+    Args:
+        pa_table (:obj:`pyarrow.Table`): PyArrow table to iterate over
+        batch_size (:obj:`int`): Size of each sub-table to yield
+        drop_last_batch (:obj:`bool`, default `False`): Drop the last batch if it is smaller than `batch_size`
+    """
+    if config.PYARROW_VERSION.major < 8:
+        raise RuntimeError(f"pyarrow>=8.0.0 is needed to use table_iter but you have {config.PYARROW_VERSION}")
+    # Buffer the reader's chunks until they add up to exactly `batch_size` rows
+    chunks_buffer = []
+    chunks_buffer_size = 0
+    for chunk in pa_table.to_reader(max_chunksize=batch_size):
+        if len(chunk) == 0:
+            continue
+        elif chunks_buffer_size + len(chunk) < batch_size:
+            chunks_buffer.append(chunk)
+            chunks_buffer_size += len(chunk)
+            continue
+        elif chunks_buffer_size + len(chunk) == batch_size:
+            chunks_buffer.append(chunk)
+            yield pa.Table.from_batches(chunks_buffer)
+            chunks_buffer = []
+            chunks_buffer_size = 0
+        else:
+            # The chunk would overflow the batch: split it, yield a full batch and buffer the remainder
+            cropped_chunk_length = batch_size - chunks_buffer_size
+            chunks_buffer.append(chunk.slice(0, cropped_chunk_length))
+            yield pa.Table.from_batches(chunks_buffer)
+            chunks_buffer = [chunk.slice(cropped_chunk_length, len(chunk) - cropped_chunk_length)]
+            chunks_buffer_size = len(chunk) - cropped_chunk_length
+    if not drop_last_batch and chunks_buffer:
+        yield pa.Table.from_batches(chunks_buffer)
diff --git a/tests/test_table.py b/tests/test_table.py
index 4bb31900dea..b622d089476 100644
--- a/tests/test_table.py
+++ b/tests/test_table.py
@@ -6,6 +6,7 @@
 import pyarrow as pa
 import pytest
 
+import datasets
 from datasets import Sequence, Value
 from datasets.features.features import ClassLabel, Features, Image
 from datasets.table import (
@@ -24,6 +25,7 @@
     embed_table_storage,
     inject_arrow_table_documentation,
     table_cast,
+    table_iter,
 )
 
 from .utils import assert_arrow_memory_doesnt_increase, assert_arrow_memory_increases, slow
@@ -1118,3 +1120,29 @@ def test_embed_table_storage(image_file):
     embedded_images_table = embed_table_storage(table)
     assert embedded_images_table.to_pydict()["image"][0]["path"] is None
     assert isinstance(embedded_images_table.to_pydict()["image"][0]["bytes"], bytes)
+
+
+@pytest.mark.skipif(datasets.config.PYARROW_VERSION.major < 8, reason="only available on pyarrow>=8")
+@pytest.mark.parametrize(
+    "pa_table",
+    [
+        pa.table({"foo": range(10)}),
+        pa.concat_tables([pa.table({"foo": range(0, 5)}), pa.table({"foo": range(5, 10)})]),
+        pa.concat_tables([pa.table({"foo": [i]}) for i in range(10)]),
+    ],
+)
+@pytest.mark.parametrize("batch_size", [1, 2, 3, 9, 10, 11, 20])
+@pytest.mark.parametrize("drop_last_batch", [False, True])
+def test_table_iter(pa_table, batch_size, drop_last_batch):
+    num_rows = len(pa_table) if not drop_last_batch else len(pa_table) // batch_size * batch_size
+    num_batches = (num_rows // batch_size) + 1 if num_rows % batch_size else num_rows // batch_size
+    subtables = list(table_iter(pa_table, batch_size=batch_size, drop_last_batch=drop_last_batch))
+    assert len(subtables) == num_batches
+    if drop_last_batch:
+        assert all(len(subtable) == batch_size for subtable in subtables)
+    else:
+        assert all(len(subtable) == batch_size for subtable in subtables[:-1])
+        assert len(subtables[-1]) <= batch_size
+    if num_rows > 0:
+        reloaded = pa.concat_tables(subtables)
+        assert pa_table.slice(0, num_rows).to_pydict() == reloaded.to_pydict()
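
A minimal usage sketch of the new `table_iter` helper, assuming pyarrow>=8.0.0 is installed (as the version check above requires). The chunk layout (2 + 3 + 5 rows) is deliberately misaligned with `batch_size` so the buffering and re-slicing path is exercised:

```python
import pyarrow as pa

from datasets.table import table_iter

# Build a table whose underlying chunks (2 + 3 + 5 rows) do not line up
# with the requested batch size, so table_iter must buffer and re-slice them.
pa_table = pa.concat_tables(
    [
        pa.table({"foo": [0, 1]}),
        pa.table({"foo": [2, 3, 4]}),
        pa.table({"foo": [5, 6, 7, 8, 9]}),
    ]
)

for subtable in table_iter(pa_table, batch_size=4):
    print(len(subtable), subtable["foo"].to_pylist())
# 4 [0, 1, 2, 3]
# 4 [4, 5, 6, 7]
# 2 [8, 9]   <- the smaller final batch is kept because drop_last_batch defaults to False
```

The yielded sub-tables are built from zero-copy slices of the original Arrow data, which is what keeps this iteration path cheap compared to materializing each batch.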