From c30edfe1475008155dbc9ef5f52f6d4ac66c81ac Mon Sep 17 00:00:00 2001
From: Quentin Lhoest
Date: Fri, 14 Oct 2022 14:00:50 +0200
Subject: [PATCH 01/11] fix iter_batches

---
 src/datasets/arrow_dataset.py | 23 +++++++++++------------
 src/datasets/table.py         | 23 +++++++++++++++++++++++
 2 files changed, 34 insertions(+), 12 deletions(-)

diff --git a/src/datasets/arrow_dataset.py b/src/datasets/arrow_dataset.py
index d3db912357b..24070f70688 100644
--- a/src/datasets/arrow_dataset.py
+++ b/src/datasets/arrow_dataset.py
@@ -95,6 +95,7 @@
     Table,
     concat_tables,
     embed_table_storage,
+    table_iter_batches,
     list_table_cache_files,
     table_cast,
     table_visitor,
@@ -1924,7 +1925,7 @@ def __len__(self):
         """
         return self.num_rows
 
-    def _iter_batches(self, batch_size: int, decoded: bool = True):
+    def _iter_batches(self, batch_size: int, decoded: bool = True, drop_last_batch: bool = False):
         """Iterate through the batches of size `batch_size`.
 
         If a formatting is set with :meth:`Dataset.set_format` rows will be returned with the
@@ -1935,16 +1936,15 @@ def _iter_batches(self, batch_size: int, decoded: bool = True):
             # Benchmark: https://gist.github.com/mariosasko/0248288a2e3a7556873969717c1fe52b (fast_iter_batch)
             format_kwargs = self._format_kwargs if self._format_kwargs is not None else {}
             formatter = get_formatter(self._format_type, features=self.features, decoded=decoded, **format_kwargs)
-            for batch in self.data.to_reader(max_chunksize=batch_size):
-                pa_subtable = pa.Table.from_batches([batch])
-                formatted_output = format_table(
+            for pa_subtable in table_iter_batches(self.data, batch_size=batch_size, drop_last_batch=drop_last_batch):
+                formatted_batch = format_table(
                     pa_subtable,
                     range(pa_subtable.num_rows),
                     formatter=formatter,
                     format_columns=self._format_columns,
                     output_all_columns=self._output_all_columns,
                 )
-                yield formatted_output
+                yield formatted_batch
         else:
             for i in range(0, self.num_rows, batch_size):
                 yield self._getitem(
@@ -1964,12 +1964,11 @@ def _iter(self, decoded: bool = True):
             format_kwargs = self._format_kwargs if self._format_kwargs is not None else {}
             formatter = get_formatter(self._format_type, features=self.features, decoded=decoded, **format_kwargs)
             batch_size = config.ARROW_READER_BATCH_SIZE_IN_DATASET_ITER
-            for batch in self.data.to_reader(max_chunksize=batch_size):
-                for i in range(batch.num_rows):
-                    batch_ex = batch.slice(i, 1)
-                    pa_subtable = pa.Table.from_batches([batch_ex])
+            for pa_subtable in table_iter_batches(self.data, batch_size=batch_size):
+                for i in range(pa_subtable.num_rows):
+                    pa_subtable_ex = pa_subtable.slice(i, 1)
                     formatted_output = format_table(
-                        pa_subtable,
+                        pa_subtable_ex,
                         0,
                         formatter=formatter,
                         format_columns=self._format_columns,
@@ -2936,8 +2935,8 @@ def init_buffer_and_writer():
                 len(input_dataset) if not drop_last_batch else len(input_dataset) // batch_size * batch_size
             )
             pbar_total = (num_rows // batch_size) + 1 if num_rows % batch_size else num_rows // batch_size
-            pbar_iterable = itertools.islice(
-                input_dataset._iter_batches(batch_size, decoded=False), pbar_total
+            pbar_iterable = input_dataset._iter_batches(
+                batch_size, decoded=False, drop_last_batch=drop_last_batch
             )
             pbar_unit = "ex" if not batched else "ba"
             pbar_desc = (desc + " " if desc is not None else "") + "#" + str(rank) if rank is not None else desc
diff --git a/src/datasets/table.py b/src/datasets/table.py
index ed56c195172..f1f490011a6 100644
--- a/src/datasets/table.py
+++ b/src/datasets/table.py
@@ -2148,3 +2148,26 @@ def _visit(array, feature):
 
     for name, feature in features.items():
        _visit(table[name], feature)
+
+
+def table_iter_batches(pa_table: pa.Table, batch_size: int, drop_last_batch=False):
+    chunks_buffer = []
+    chunks_buffer_size = 0
+    for chunk in pa_table.to_reader(max_chunksize=batch_size):
+        if chunks_buffer_size + len(chunk) < batch_size:
+            chunks_buffer.append(chunk)
+            chunks_buffer_size += len(chunk)
+            continue
+        elif chunks_buffer_size + len(chunk) == batch_size:
+            chunks_buffer.append(chunk)
+            yield pa.Table.from_batches(chunks_buffer)
+            chunks_buffer = []
+            chunks_buffer_size = 0
+        else:
+            cropped_chunk_length = batch_size - chunks_buffer_size
+            chunks_buffer.append(chunk.slice(0, cropped_chunk_length))
+            yield pa.Table.from_batches(chunks_buffer)
+            chunks_buffer = [chunk.slice(cropped_chunk_length, len(chunk) - cropped_chunk_length)]
+            chunks_buffer_size = cropped_chunk_length
+    if not drop_last_batch:
+        yield pa.Table.from_batches(chunks_buffer)
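The new `table_iter_batches` helper exists because `pa.Table.to_reader(max_chunksize=batch_size)` only caps the size of each record batch; it never coalesces record batches that are already smaller than `batch_size`, so a heavily chunked table comes back in undersized pieces. A minimal repro of that behavior (a sketch, assuming pyarrow>=8.0.0, where `Table.to_reader` is available):

```python
import pyarrow as pa

# A 10-row table backed by ten 1-row chunks, the same pattern used by
# the tests added later in this PR.
table = pa.concat_tables([pa.table({"foo": [i]}) for i in range(10)])

# max_chunksize is only an upper bound: pre-existing small chunks are not
# merged, so every batch below has a single row despite asking for 3.
for batch in table.to_reader(max_chunksize=3):
    print(batch.num_rows)  # -> 1, printed ten times
```

Buffering chunks until exactly `batch_size` rows have accumulated is what lets `_iter_batches` keep its contract of fixed-size batches regardless of the table's chunking.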
From ecf353a8188b3e4c767b5981e5943a99f4351de0 Mon Sep 17 00:00:00 2001
From: Quentin Lhoest
Date: Fri, 14 Oct 2022 14:08:06 +0200
Subject: [PATCH 02/11] minor

---
 src/datasets/arrow_dataset.py | 2 +-
 src/datasets/table.py         | 6 ++++--
 2 files changed, 5 insertions(+), 3 deletions(-)

diff --git a/src/datasets/arrow_dataset.py b/src/datasets/arrow_dataset.py
index 24070f70688..3ee97e7bf05 100644
--- a/src/datasets/arrow_dataset.py
+++ b/src/datasets/arrow_dataset.py
@@ -95,9 +95,9 @@
     Table,
     concat_tables,
     embed_table_storage,
-    table_iter_batches,
     list_table_cache_files,
     table_cast,
+    table_iter_batches,
     table_visitor,
 )
 from .tasks import TaskTemplate
diff --git a/src/datasets/table.py b/src/datasets/table.py
index f1f490011a6..f0674b66992 100644
--- a/src/datasets/table.py
+++ b/src/datasets/table.py
@@ -2154,7 +2154,9 @@ def table_iter_batches(pa_table: pa.Table, batch_size: int, drop_last_batch=Fals
     chunks_buffer = []
     chunks_buffer_size = 0
     for chunk in pa_table.to_reader(max_chunksize=batch_size):
-        if chunks_buffer_size + len(chunk) < batch_size:
+        if len(chunk) == 0:
+            continue
+        elif chunks_buffer_size + len(chunk) < batch_size:
             chunks_buffer.append(chunk)
             chunks_buffer_size += len(chunk)
             continue
@@ -2169,5 +2171,5 @@ def table_iter_batches(pa_table: pa.Table, batch_size: int, drop_last_batch=Fals
             yield pa.Table.from_batches(chunks_buffer)
             chunks_buffer = [chunk.slice(cropped_chunk_length, len(chunk) - cropped_chunk_length)]
             chunks_buffer_size = cropped_chunk_length
-    if not drop_last_batch:
+    if not drop_last_batch and chunks_buffer:
         yield pa.Table.from_batches(chunks_buffer)

From 3bdf8ef69b8b3b72e482159c9b6c6501626f60ab Mon Sep 17 00:00:00 2001
From: Quentin Lhoest
Date: Fri, 14 Oct 2022 15:37:30 +0200
Subject: [PATCH 03/11] fix tests

---
 src/datasets/table.py | 9 ++++++++-
 1 file changed, 8 insertions(+), 1 deletion(-)

diff --git a/src/datasets/table.py b/src/datasets/table.py
index f0674b66992..8bd5a0bd67a 100644
--- a/src/datasets/table.py
+++ b/src/datasets/table.py
@@ -2151,6 +2151,13 @@ def _visit(array, feature):
 
 
 def table_iter_batches(pa_table: pa.Table, batch_size: int, drop_last_batch=False):
+    """Iterate ober sub-tables of size `batch_size`.
+
+    Args:
+        table (:obj:`pyarrow.Table`): PyArrow table to visit
+        batch_size (:obj:`int`): size of each sub-table to yield
+        drop_last_batch (:obj:`bool`, default `False`): Drop the last batch if it is smaller than `batch_size`
+    """
     chunks_buffer = []
     chunks_buffer_size = 0
     for chunk in pa_table.to_reader(max_chunksize=batch_size):
@@ -2170,6 +2177,6 @@ def table_iter_batches(pa_table: pa.Table, batch_size: int, drop_last_batch=Fals
             cropped_chunk_length = batch_size - chunks_buffer_size
             chunks_buffer.append(chunk.slice(0, cropped_chunk_length))
             yield pa.Table.from_batches(chunks_buffer)
             chunks_buffer = [chunk.slice(cropped_chunk_length, len(chunk) - cropped_chunk_length)]
-            chunks_buffer_size = cropped_chunk_length
+            chunks_buffer_size = len(chunk) - cropped_chunk_length
     if not drop_last_batch and chunks_buffer:
         yield pa.Table.from_batches(chunks_buffer)
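The bookkeeping fix in PATCH 03 only matters once a cropped chunk's leftover is carried into the next buffer: the leftover has `len(chunk) - cropped_chunk_length` rows, not `cropped_chunk_length`. A standalone sketch of the patched logic (hedged: simplified, with `drop_last_batch` and empty-chunk handling omitted; `iter_batches` is a local name, not the library's):

```python
import pyarrow as pa

def iter_batches(pa_table: pa.Table, batch_size: int):
    # Simplified version of the buffering logic as of PATCH 03.
    chunks_buffer, chunks_buffer_size = [], 0
    for chunk in pa_table.to_reader(max_chunksize=batch_size):
        if chunks_buffer_size + len(chunk) < batch_size:
            chunks_buffer.append(chunk)
            chunks_buffer_size += len(chunk)
        elif chunks_buffer_size + len(chunk) == batch_size:
            chunks_buffer.append(chunk)
            yield pa.Table.from_batches(chunks_buffer)
            chunks_buffer, chunks_buffer_size = [], 0
        else:
            cropped = batch_size - chunks_buffer_size
            chunks_buffer.append(chunk.slice(0, cropped))
            yield pa.Table.from_batches(chunks_buffer)
            chunks_buffer = [chunk.slice(cropped, len(chunk) - cropped)]
            # PATCH 03: the carried-over slice holds len(chunk) - cropped
            # rows, not `cropped` rows.
            chunks_buffer_size = len(chunk) - cropped
    if chunks_buffer:
        yield pa.Table.from_batches(chunks_buffer)

# With chunk sizes [1, 3, 3] and batch_size=3, the old counter drifted and
# produced batches of [3, 2, 2]; the fix restores [3, 3, 1]:
table = pa.concat_tables(
    [pa.table({"foo": [0]}), pa.table({"foo": [1, 2, 3]}), pa.table({"foo": [4, 5, 6]})]
)
print([len(t) for t in iter_batches(table, 3)])  # [3, 3, 1]
```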
From 60e4a767f7f29838b27c8677ac8b3a7daa6e9aab Mon Sep 17 00:00:00 2001
From: Quentin Lhoest
Date: Fri, 14 Oct 2022 15:37:35 +0200
Subject: [PATCH 04/11] add more tests

---
 tests/test_table.py | 28 ++++++++++++++++++++++++++++
 1 file changed, 28 insertions(+)

diff --git a/tests/test_table.py b/tests/test_table.py
index 4bb31900dea..da9838b2916 100644
--- a/tests/test_table.py
+++ b/tests/test_table.py
@@ -24,6 +24,7 @@
     embed_table_storage,
     inject_arrow_table_documentation,
     table_cast,
+    table_iter_batches,
 )
 
 from .utils import assert_arrow_memory_doesnt_increase, assert_arrow_memory_increases, slow
@@ -1118,3 +1119,30 @@ def test_embed_table_storage(image_file):
     embedded_images_table = embed_table_storage(table)
     assert embedded_images_table.to_pydict()["image"][0]["path"] is None
     assert isinstance(embedded_images_table.to_pydict()["image"][0]["bytes"], bytes)
+
+
+@pytest.mark.parametrize(
+    "pa_table",
+    [
+        pa.table({"foo": range(10)}),
+        pa.concat_tables([pa.table({"foo": range(0, 5)}), pa.table({"foo": range(5, 10)})]),
+        pa.concat_tables([pa.table({"foo": [i]}) for i in range(10)]),
+    ],
+)
+@pytest.mark.parametrize("batch_size", [1, 2, 3, 9, 10, 11, 20])
+@pytest.mark.parametrize("drop_last_batch", [False, True])
+def test_table_iter_batches(pa_table, batch_size, drop_last_batch):
+    num_rows = len(pa_table) if not drop_last_batch else len(pa_table) // batch_size * batch_size
+    num_batches = (num_rows // batch_size) + 1 if num_rows % batch_size else num_rows // batch_size
+    subtables = list(table_iter_batches(pa_table, batch_size=batch_size, drop_last_batch=drop_last_batch))
+    assert len(subtables) == num_batches
+    if drop_last_batch:
+        if not all(len(subtable) == batch_size for subtable in subtables):
+            raise ArithmeticError([len(subtable) for subtable in subtables])
+        assert all(len(subtable) == batch_size for subtable in subtables)
+    else:
+        assert all(len(subtable) == batch_size for subtable in subtables[:-1])
+        assert len(subtables[-1]) <= batch_size
+    if num_rows > 0:
+        reloaded = pa.concat_tables(subtables)
+        assert pa_table.slice(0, num_rows).to_pydict() == reloaded.to_pydict()
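The expected-count expression in the new test is ceiling division spelled out with floor division and a remainder check. A quick self-contained verification against `math.ceil` over the parametrized sizes (hedged: a checking snippet for illustration, not part of the test suite):

```python
import math

num_rows = 10  # all three parametrized tables have 10 rows
for batch_size in (1, 2, 3, 9, 10, 11, 20):
    # The test's expression for the number of yielded sub-tables:
    num_batches = (num_rows // batch_size) + 1 if num_rows % batch_size else num_rows // batch_size
    assert num_batches == math.ceil(num_rows / batch_size)
    # e.g. batch_size=3 -> 4 batches of sizes 3, 3, 3, 1
```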
From 30059edb6433029e5c16585bf1db4280b926bb20 Mon Sep 17 00:00:00 2001
From: Quentin Lhoest
Date: Fri, 14 Oct 2022 15:45:15 +0200
Subject: [PATCH 05/11] docs + rename

---
 src/datasets/arrow_dataset.py | 6 +++---
 src/datasets/table.py         | 4 ++--
 tests/test_table.py           | 6 +++---
 3 files changed, 8 insertions(+), 8 deletions(-)

diff --git a/src/datasets/arrow_dataset.py b/src/datasets/arrow_dataset.py
index 3ee97e7bf05..cb32908515c 100644
--- a/src/datasets/arrow_dataset.py
+++ b/src/datasets/arrow_dataset.py
@@ -97,7 +97,7 @@
     embed_table_storage,
     list_table_cache_files,
     table_cast,
-    table_iter_batches,
+    table_iter,
     table_visitor,
 )
 from .tasks import TaskTemplate
@@ -1936,7 +1936,7 @@ def _iter_batches(self, batch_size: int, decoded: bool = True, drop_last_batch:
             # Benchmark: https://gist.github.com/mariosasko/0248288a2e3a7556873969717c1fe52b (fast_iter_batch)
             format_kwargs = self._format_kwargs if self._format_kwargs is not None else {}
             formatter = get_formatter(self._format_type, features=self.features, decoded=decoded, **format_kwargs)
-            for pa_subtable in table_iter_batches(self.data, batch_size=batch_size, drop_last_batch=drop_last_batch):
+            for pa_subtable in table_iter(self.data, batch_size=batch_size, drop_last_batch=drop_last_batch):
                 formatted_batch = format_table(
                     pa_subtable,
                     range(pa_subtable.num_rows),
@@ -1964,7 +1964,7 @@ def _iter(self, decoded: bool = True):
             format_kwargs = self._format_kwargs if self._format_kwargs is not None else {}
             formatter = get_formatter(self._format_type, features=self.features, decoded=decoded, **format_kwargs)
             batch_size = config.ARROW_READER_BATCH_SIZE_IN_DATASET_ITER
-            for pa_subtable in table_iter_batches(self.data, batch_size=batch_size):
+            for pa_subtable in table_iter(self.data, batch_size=batch_size):
                 for i in range(pa_subtable.num_rows):
                     pa_subtable_ex = pa_subtable.slice(i, 1)
                     formatted_output = format_table(
diff --git a/src/datasets/table.py b/src/datasets/table.py
index 8bd5a0bd67a..4939eafbab4 100644
--- a/src/datasets/table.py
+++ b/src/datasets/table.py
@@ -2150,11 +2150,11 @@ def _visit(array, feature):
         _visit(table[name], feature)
 
 
-def table_iter_batches(pa_table: pa.Table, batch_size: int, drop_last_batch=False):
+def table_iter(pa_table: pa.Table, batch_size: int, drop_last_batch=False):
     """Iterate ober sub-tables of size `batch_size`.
 
     Args:
-        table (:obj:`pyarrow.Table`): PyArrow table to visit
+        table (:obj:`pyarrow.Table`): PyArrow table to iterate over
         batch_size (:obj:`int`): size of each sub-table to yield
         drop_last_batch (:obj:`bool`, default `False`): Drop the last batch if it is smaller than `batch_size`
     """
diff --git a/tests/test_table.py b/tests/test_table.py
index da9838b2916..40176d20a99 100644
--- a/tests/test_table.py
+++ b/tests/test_table.py
@@ -24,7 +24,7 @@
     embed_table_storage,
     inject_arrow_table_documentation,
     table_cast,
-    table_iter_batches,
+    table_iter,
 )
 
 from .utils import assert_arrow_memory_doesnt_increase, assert_arrow_memory_increases, slow
@@ -1131,10 +1131,10 @@
 )
 @pytest.mark.parametrize("batch_size", [1, 2, 3, 9, 10, 11, 20])
 @pytest.mark.parametrize("drop_last_batch", [False, True])
-def test_table_iter_batches(pa_table, batch_size, drop_last_batch):
+def test_table_iter(pa_table, batch_size, drop_last_batch):
     num_rows = len(pa_table) if not drop_last_batch else len(pa_table) // batch_size * batch_size
     num_batches = (num_rows // batch_size) + 1 if num_rows % batch_size else num_rows // batch_size
-    subtables = list(table_iter_batches(pa_table, batch_size=batch_size, drop_last_batch=drop_last_batch))
+    subtables = list(table_iter(pa_table, batch_size=batch_size, drop_last_batch=drop_last_batch))
     assert len(subtables) == num_batches
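After this rename the helper reads as a small reusable utility. A usage sketch, assuming this branch of `datasets` is installed alongside pyarrow>=8.0.0 (the `"foo"` column is just an illustrative name):

```python
import pyarrow as pa
from datasets.table import table_iter  # name introduced by this patch

table = pa.table({"foo": list(range(10))})

# Every yielded sub-table has exactly batch_size rows, except possibly
# the last one.
for subtable in table_iter(table, batch_size=4):
    print(subtable.num_rows)  # -> 4, 4, 2

# With drop_last_batch=True the trailing partial batch is skipped.
for subtable in table_iter(table, batch_size=4, drop_last_batch=True):
    print(subtable.num_rows)  # -> 4, 4
```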
From 0600ef9e42d808b644e591cbedbaed359cfdc298 Mon Sep 17 00:00:00 2001
From: Quentin Lhoest
Date: Fri, 14 Oct 2022 15:47:46 +0200
Subject: [PATCH 06/11] remove tmp lines

---
 tests/test_table.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/tests/test_table.py b/tests/test_table.py
index 40176d20a99..d20a26363d5 100644
--- a/tests/test_table.py
+++ b/tests/test_table.py
@@ -1137,8 +1137,6 @@ def test_table_iter(pa_table, batch_size, drop_last_batch):
     subtables = list(table_iter(pa_table, batch_size=batch_size, drop_last_batch=drop_last_batch))
     assert len(subtables) == num_batches
     if drop_last_batch:
-        if not all(len(subtable) == batch_size for subtable in subtables):
-            raise ArithmeticError([len(subtable) for subtable in subtables])
         assert all(len(subtable) == batch_size for subtable in subtables)
     else:
         assert all(len(subtable) == batch_size for subtable in subtables[:-1])
         assert len(subtables[-1]) <= batch_size

From 1856601ea4a7d8fe6131bc65ccb2b40ead2aef0e Mon Sep 17 00:00:00 2001
From: Quentin Lhoest
Date: Fri, 14 Oct 2022 16:00:09 +0200
Subject: [PATCH 07/11] run test only if pyarrow>=8

---
 tests/test_table.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/tests/test_table.py b/tests/test_table.py
index d20a26363d5..b622d089476 100644
--- a/tests/test_table.py
+++ b/tests/test_table.py
@@ -6,6 +6,7 @@
 import pyarrow as pa
 import pytest
 
+import datasets
 from datasets import Sequence, Value
 from datasets.features.features import ClassLabel, Features, Image
 from datasets.table import (
@@ -1121,6 +1122,7 @@ def test_embed_table_storage(image_file):
     assert isinstance(embedded_images_table.to_pydict()["image"][0]["bytes"], bytes)
 
 
+@pytest.mark.skipif(datasets.config.PYARROW_VERSION.major < 8, reason="only available on pyarrow>=8")
 @pytest.mark.parametrize(
     "pa_table",
     [

From e4ae9cd94c7ca930593887bc5d0e869175e919e9 Mon Sep 17 00:00:00 2001
From: Quentin Lhoest
Date: Fri, 14 Oct 2022 16:02:10 +0200
Subject: [PATCH 08/11] add error message for old versions of pyarrow

---
 src/datasets/table.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/src/datasets/table.py b/src/datasets/table.py
index 4939eafbab4..376e1e960f5 100644
--- a/src/datasets/table.py
+++ b/src/datasets/table.py
@@ -2158,6 +2158,8 @@ def table_iter(pa_table: pa.Table, batch_size: int, drop_last_batch=False):
         batch_size (:obj:`int`): size of each sub-table to yield
         drop_last_batch (:obj:`bool`, default `False`): Drop the last batch if it is smaller than `batch_size`
     """
+    if config.PYARROW_VERSION.major < 8:
+        raise RuntimeError(f"pyarrow>=8.0.0 is needed to use table_iter but you have {config.PYARROW_VERSION}")
     chunks_buffer = []
     chunks_buffer_size = 0
     for chunk in pa_table.to_reader(max_chunksize=batch_size):
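Both guards key off the same condition, since `pa.Table.to_reader` only exists from pyarrow 8 onward. In standalone form the runtime check from PATCH 08 boils down to the sketch below (hedged: `PYARROW_VERSION` is recreated locally here, on the assumption that `datasets.config.PYARROW_VERSION` is a parsed `packaging` version of the installed pyarrow):

```python
import pyarrow
from packaging import version

# Local stand-in for datasets.config.PYARROW_VERSION.
PYARROW_VERSION = version.parse(pyarrow.__version__)

# Fail fast with an actionable message instead of an AttributeError on
# Table.to_reader.
if PYARROW_VERSION.major < 8:
    raise RuntimeError(f"pyarrow>=8.0.0 is needed to use table_iter but you have {PYARROW_VERSION}")
```

The `pytest.mark.skipif` in PATCH 07 mirrors the same condition on the test side, so CI jobs pinned to older pyarrow skip the test rather than hit the RuntimeError.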
From ed7631d05d4ce5698c12fd8550946e7589ceb30b Mon Sep 17 00:00:00 2001
From: Quentin Lhoest
Date: Fri, 14 Oct 2022 16:03:26 +0200
Subject: [PATCH 09/11] mention pyarrow>=8 in docstring as well

---
 src/datasets/table.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/src/datasets/table.py b/src/datasets/table.py
index 376e1e960f5..b5e039faabf 100644
--- a/src/datasets/table.py
+++ b/src/datasets/table.py
@@ -2152,6 +2152,8 @@ def _visit(array, feature):
 def table_iter(pa_table: pa.Table, batch_size: int, drop_last_batch=False):
     """Iterate ober sub-tables of size `batch_size`.
+
+    Requires pyarrow>=8.0.0
 
     Args:
         table (:obj:`pyarrow.Table`): PyArrow table to iterate over
         batch_size (:obj:`int`): size of each sub-table to yield

From 4ac0c1b3e9bfb57218eb67be59d5cfc587e3e136 Mon Sep 17 00:00:00 2001
From: Quentin Lhoest
Date: Fri, 14 Oct 2022 16:03:43 +0200
Subject: [PATCH 10/11] style

---
 src/datasets/table.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/datasets/table.py b/src/datasets/table.py
index b5e039faabf..4fb0a0002e1 100644
--- a/src/datasets/table.py
+++ b/src/datasets/table.py
@@ -2152,7 +2152,7 @@ def _visit(array, feature):
 
 def table_iter(pa_table: pa.Table, batch_size: int, drop_last_batch=False):
     """Iterate ober sub-tables of size `batch_size`.
-    
+
     Requires pyarrow>=8.0.0
 
     Args:

From d5f243ccc9f16f3a19bdfb69f7a084e066daff6e Mon Sep 17 00:00:00 2001
From: Quentin Lhoest <42851186+lhoestq@users.noreply.github.com>
Date: Fri, 14 Oct 2022 16:43:28 +0200
Subject: [PATCH 11/11] Update src/datasets/table.py

Co-authored-by: Albert Villanova del Moral
 <8515462+albertvillanova@users.noreply.github.com>
---
 src/datasets/table.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/datasets/table.py b/src/datasets/table.py
index 4fb0a0002e1..f585d72efb0 100644
--- a/src/datasets/table.py
+++ b/src/datasets/table.py
@@ -2151,7 +2151,7 @@ def _visit(array, feature):
 
 def table_iter(pa_table: pa.Table, batch_size: int, drop_last_batch=False):
-    """Iterate ober sub-tables of size `batch_size`.
+    """Iterate over sub-tables of size `batch_size`.
 
     Requires pyarrow>=8.0.0