From 302502429c953ba5bf483ec4c22fd023eb5ab387 Mon Sep 17 00:00:00 2001
From: mariosasko
Date: Tue, 27 Sep 2022 17:09:50 +0200
Subject: [PATCH 1/4] Fast dataset iter

---
 src/datasets/arrow_dataset.py | 66 ++++++++++++++++++++++++++++-------
 src/datasets/table.py         | 24 +++++++++++++
 2 files changed, 78 insertions(+), 12 deletions(-)

diff --git a/src/datasets/arrow_dataset.py b/src/datasets/arrow_dataset.py
index 92267650cbb..4b0c9537a3c 100644
--- a/src/datasets/arrow_dataset.py
+++ b/src/datasets/arrow_dataset.py
@@ -1854,17 +1854,61 @@ def __len__(self):
         """
         return self.num_rows
 
+    def _iter_batches(self, batch_size, decoded: bool = True):
+        """Iterate through the batches of size `batch_size`.
+
+        If a formatting is set with :meth:`Dataset.set_format` rows will be returned with the
+        selected format.
+        """
+        if self._indices is None and config.PYARROW_VERSION.major >= 8:
+            # Fast iteration
+            format_kwargs = self._format_kwargs if self._format_kwargs is not None else {}
+            formatter = get_formatter(self._format_type, features=self.features, decoded=decoded, **format_kwargs)
+            schema = self.data._schema
+            for i, batch in enumerate(self.data.to_reader(max_chunksize=batch_size)):
+                pa_subtable = pa.Table.from_batches([batch], schema=schema)
+                formatted_output = format_table(
+                    pa_subtable,
+                    slice(i, i + batch_size),
+                    formatter=formatter,
+                    format_columns=self._format_columns,
+                    output_all_columns=self._output_all_columns,
+                )
+                yield formatted_output
+        else:
+            for i in range(0, self.num_rows, batch_size):
+                yield self._getitem(
+                    slice(i, i + batch_size),
+                    decoded=decoded,
+                )
+
     def _iter(self, decoded: bool = True):
         """Iterate through the examples.
 
         If a formatting is set with :meth:`Dataset.set_format` rows will be returned with the
         selected format.
         """
-        for index in range(self.num_rows):
-            yield self._getitem(
-                index,
-                decoded=decoded,
-            )
+        if self._indices is None and config.PYARROW_VERSION.major >= 8:
+            # Fast iteration
+            format_kwargs = self._format_kwargs if self._format_kwargs is not None else {}
+            formatter = get_formatter(self._format_type, features=self.features, decoded=decoded, **format_kwargs)
+            schema = self.data._schema
+            for i, batch in enumerate(self.data.to_reader(max_chunksize=1)):
+                pa_subtable = pa.Table.from_batches([batch], schema=schema)
+                formatted_output = format_table(
+                    pa_subtable,
+                    i,
+                    formatter=formatter,
+                    format_columns=self._format_columns,
+                    output_all_columns=self._output_all_columns,
+                )
+                yield formatted_output
+        else:
+            for i in range(self.num_rows):
+                yield self._getitem(
+                    i,
+                    decoded=decoded,
+                )
 
     def __iter__(self):
         """Iterate through the examples.
@@ -2805,14 +2849,16 @@ def init_buffer_and_writer():
 
         # Loop over single examples or batches and write to buffer/file if examples are to be updated
         if not batched:
-            pbar_iterable = input_dataset._iter(decoded=False)
             pbar_total = len(input_dataset)
+            pbar_iterable = input_dataset._iter(decoded=False)
         else:
             num_rows = (
                 len(input_dataset) if not drop_last_batch else len(input_dataset) // batch_size * batch_size
             )
-            pbar_iterable = range(0, num_rows, batch_size)
             pbar_total = (num_rows // batch_size) + 1 if num_rows % batch_size else num_rows // batch_size
+            pbar_iterable = itertools.islice(
+                input_dataset._iter_batches(batch_size, decoded=False), pbar_total
+            )
         pbar_unit = "ex" if not batched else "ba"
         pbar_desc = (desc + " " if desc is not None else "") + "#" + str(rank) if rank is not None else desc
         pbar = logging.tqdm(
@@ -2835,11 +2881,7 @@ def init_buffer_and_writer():
                     else:
                         writer.write(example)
             else:
-                for i in pbar:
-                    batch = input_dataset._getitem(
-                        slice(i, i + batch_size),
-                        decoded=False,
-                    )
+                for i, batch in enumerate(pbar):
                     indices = list(
                         range(*(slice(i, i + batch_size).indices(input_dataset.num_rows)))
                     )  # Something simpler?
diff --git a/src/datasets/table.py b/src/datasets/table.py
index 6ea551626b4..ed56c195172 100644
--- a/src/datasets/table.py
+++ b/src/datasets/table.py
@@ -330,6 +330,30 @@ def to_pandas(self, *args, **kwargs):
     def to_string(self, *args, **kwargs):
         return self.table.to_string(*args, **kwargs)
 
+    def to_reader(self, *args, **kwargs):
+        """
+        Convert the Table to a RecordBatchReader.
+
+        Note that this method is zero-copy, it merely exposes the same data under a different API.
+
+        Args:
+            max_chunksize (:obj:`int`, defaults to :obj:`None`)
+                Maximum size for RecordBatch chunks. Individual chunks may be smaller depending
+                on the chunk layout of individual columns.
+
+        Returns:
+            :obj:`pyarrow.RecordBatchReader`
+
+        <Tip warning={true}>
+
+        pyarrow >= 8.0.0 needs to be installed to use this method.
+
+        </Tip>
+        """
+        if config.PYARROW_VERSION.major < 8:
+            raise NotImplementedError("`pyarrow>=8.0.0` is required to use this method")
+        return self.table.to_reader(*args, **kwargs)
+
     def field(self, *args, **kwargs):
         """
         Select a schema field by its column name or numeric index.
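
The patch above routes `Dataset._iter`, `Dataset._iter_batches`, and the `map` loop through the new `Table.to_reader` wrapper. As a rough illustration of the underlying pattern only (not code from this patch), here is a minimal pyarrow-only sketch, assuming pyarrow >= 8.0.0; the toy table and the chunk size of 4 are illustrative:

    import pyarrow as pa

    table = pa.table({"idx": list(range(10)), "text": [f"example {i}" for i in range(10)]})

    # Slow path kept as the fallback in the patch: positional slicing of the table per step.
    for start in range(0, table.num_rows, 4):
        chunk = table.slice(start, 4)  # each step re-resolves positions in the chunked table

    # Fast path: stream RecordBatches zero-copy and wrap each one back into a Table,
    # so downstream formatting code keeps receiving a Table.
    for batch in table.to_reader(max_chunksize=4):
        pa_subtable = pa.Table.from_batches([batch])
        print(pa_subtable.num_rows, pa_subtable.to_pydict()["idx"])

Wrapping each RecordBatch back into a Table keeps the existing `format_table` path unchanged while avoiding a positional lookup per batch.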
From ae71a31211233c57b4f84d4c6423703394e0f77b Mon Sep 17 00:00:00 2001
From: mariosasko
Date: Thu, 29 Sep 2022 14:40:22 +0200
Subject: [PATCH 2/4] Final improvements + some minor fixes

---
 src/datasets/arrow_dataset.py | 31 ++++++++++++++++---------------
 src/datasets/config.py        |  3 +++
 2 files changed, 19 insertions(+), 15 deletions(-)

diff --git a/src/datasets/arrow_dataset.py b/src/datasets/arrow_dataset.py
index 4b0c9537a3c..fa9c2daddc9 100644
--- a/src/datasets/arrow_dataset.py
+++ b/src/datasets/arrow_dataset.py
@@ -1864,12 +1864,11 @@ def _iter_batches(self, batch_size, decoded: bool = True):
             # Fast iteration
             format_kwargs = self._format_kwargs if self._format_kwargs is not None else {}
             formatter = get_formatter(self._format_type, features=self.features, decoded=decoded, **format_kwargs)
-            schema = self.data._schema
-            for i, batch in enumerate(self.data.to_reader(max_chunksize=batch_size)):
-                pa_subtable = pa.Table.from_batches([batch], schema=schema)
+            for batch in self.data.to_reader(max_chunksize=batch_size):
+                pa_subtable = pa.Table.from_batches([batch])
                 formatted_output = format_table(
                     pa_subtable,
-                    slice(i, i + batch_size),
+                    range(pa_subtable.num_rows),
                     formatter=formatter,
                     format_columns=self._format_columns,
                     output_all_columns=self._output_all_columns,
@@ -1892,17 +1891,19 @@ def _iter(self, decoded: bool = True):
             # Fast iteration
             format_kwargs = self._format_kwargs if self._format_kwargs is not None else {}
             formatter = get_formatter(self._format_type, features=self.features, decoded=decoded, **format_kwargs)
-            schema = self.data._schema
-            for i, batch in enumerate(self.data.to_reader(max_chunksize=1)):
-                pa_subtable = pa.Table.from_batches([batch], schema=schema)
-                formatted_output = format_table(
-                    pa_subtable,
-                    i,
-                    formatter=formatter,
-                    format_columns=self._format_columns,
-                    output_all_columns=self._output_all_columns,
-                )
-                yield formatted_output
+            batch_size = config.DEFAULT_ITER_BATCH_SIZE
+            for batch in self.data.to_reader(max_chunksize=batch_size):
+                for i in range(batch.num_rows):
+                    batch_ex = batch.slice(i, 1)
+                    pa_subtable = pa.Table.from_batches([batch_ex])
+                    formatted_output = format_table(
+                        pa_subtable,
+                        0,
+                        formatter=formatter,
+                        format_columns=self._format_columns,
+                        output_all_columns=self._output_all_columns,
+                    )
+                    yield formatted_output
         else:
             for i in range(self.num_rows):
                 yield self._getitem(
diff --git a/src/datasets/config.py b/src/datasets/config.py
index 2bd5419cbe3..736b84cc882 100644
--- a/src/datasets/config.py
+++ b/src/datasets/config.py
@@ -168,6 +168,9 @@
 # https://github.com/apache/arrow/blob/master/docs/source/cpp/arrays.rst#size-limitations-and-recommendations)
 DEFAULT_MAX_BATCH_SIZE = 10_000
 
+# Number of examples prefetched in `Dataset.__iter__`
+DEFAULT_ITER_BATCH_SIZE = 10
+
 # Pickling tables works only for small tables (<4GiB)
 # For big tables, we write them on disk instead
 MAX_TABLE_NBYTES_FOR_PICKLING = 4 << 30

From 6701bc97cc0ec73a81e4a09302cd8245019b9e22 Mon Sep 17 00:00:00 2001
From: Mario Šaško
Date: Thu, 29 Sep 2022 16:35:39 +0200
Subject: [PATCH 3/4] Update src/datasets/arrow_dataset.py

Co-authored-by: Quentin Lhoest <42851186+lhoestq@users.noreply.github.com>
---
 src/datasets/arrow_dataset.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/datasets/arrow_dataset.py b/src/datasets/arrow_dataset.py
index fa9c2daddc9..bf0bb59cd40 100644
--- a/src/datasets/arrow_dataset.py
+++ b/src/datasets/arrow_dataset.py
@@ -1854,7 +1854,7 @@ def __len__(self):
         """
         return self.num_rows
 
-    def _iter_batches(self, batch_size, decoded: bool = True):
+    def _iter_batches(self, batch_size: int, decoded: bool = True):
         """Iterate through the batches of size `batch_size`.
 
         If a formatting is set with :meth:`Dataset.set_format` rows will be returned with the

From a5387007b81db54adfe29271080c4450b900b39a Mon Sep 17 00:00:00 2001
From: mariosasko
Date: Thu, 29 Sep 2022 16:50:29 +0200
Subject: [PATCH 4/4] Address comments

---
 src/datasets/arrow_dataset.py | 4 +++-
 src/datasets/config.py        | 4 ++--
 2 files changed, 5 insertions(+), 3 deletions(-)

diff --git a/src/datasets/arrow_dataset.py b/src/datasets/arrow_dataset.py
index bf0bb59cd40..42479e42091 100644
--- a/src/datasets/arrow_dataset.py
+++ b/src/datasets/arrow_dataset.py
@@ -1862,6 +1862,7 @@ def _iter_batches(self, batch_size: int, decoded: bool = True):
         """
         if self._indices is None and config.PYARROW_VERSION.major >= 8:
             # Fast iteration
+            # Benchmark: https://gist.github.com/mariosasko/0248288a2e3a7556873969717c1fe52b (fast_iter_batch)
             format_kwargs = self._format_kwargs if self._format_kwargs is not None else {}
             formatter = get_formatter(self._format_type, features=self.features, decoded=decoded, **format_kwargs)
             for batch in self.data.to_reader(max_chunksize=batch_size):
@@ -1889,9 +1890,10 @@ def _iter(self, decoded: bool = True):
         """
         if self._indices is None and config.PYARROW_VERSION.major >= 8:
             # Fast iteration
+            # Benchmark: https://gist.github.com/mariosasko/0248288a2e3a7556873969717c1fe52b (fast_iter_batch)
             format_kwargs = self._format_kwargs if self._format_kwargs is not None else {}
             formatter = get_formatter(self._format_type, features=self.features, decoded=decoded, **format_kwargs)
-            batch_size = config.DEFAULT_ITER_BATCH_SIZE
+            batch_size = config.ARROW_READER_BATCH_SIZE_IN_DATASET_ITER
             for batch in self.data.to_reader(max_chunksize=batch_size):
                 for i in range(batch.num_rows):
                     batch_ex = batch.slice(i, 1)
diff --git a/src/datasets/config.py b/src/datasets/config.py
index 736b84cc882..8595d5fc581 100644
--- a/src/datasets/config.py
+++ b/src/datasets/config.py
@@ -168,8 +168,8 @@
 # https://github.com/apache/arrow/blob/master/docs/source/cpp/arrays.rst#size-limitations-and-recommendations)
 DEFAULT_MAX_BATCH_SIZE = 10_000
 
-# Number of examples prefetched in `Dataset.__iter__`
-DEFAULT_ITER_BATCH_SIZE = 10
+# Size of the preloaded record batch in `Dataset.__iter__`
+ARROW_READER_BATCH_SIZE_IN_DATASET_ITER = 10
 
 # Pickling tables works only for small tables (<4GiB)
 # For big tables, we write them on disk instead
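
After the full series, `Dataset.__iter__` consumes small prefetched record batches (10 rows by default, via `config.ARROW_READER_BATCH_SIZE_IN_DATASET_ITER`) and yields one-row, zero-copy slices of each batch. Below is a standalone sketch of that loop, assuming pyarrow >= 8.0.0; the toy table, the `iter_examples` name, and the plain-dict output stand in for the library's `format_table` call and are not part of the patches:

    import pyarrow as pa

    ARROW_READER_BATCH_SIZE_IN_DATASET_ITER = 10  # mirrors the value added to config.py


    def iter_examples(table: pa.Table):
        """Yield one example at a time from small prefetched record batches."""
        for batch in table.to_reader(max_chunksize=ARROW_READER_BATCH_SIZE_IN_DATASET_ITER):
            for i in range(batch.num_rows):
                one_row = pa.Table.from_batches([batch.slice(i, 1)])  # one-row view, no data copied
                # The library formats `one_row` with `format_table(...)`; a plain dict stands in here.
                yield {name: col[0].as_py() for name, col in zip(one_row.column_names, one_row.columns)}


    table = pa.table({"idx": list(range(25)), "value": [i * i for i in range(25)]})
    print(next(iter_examples(table)))  # -> {'idx': 0, 'value': 0}

Slicing a prefetched batch avoids a fresh positional lookup on the chunked table for every example, which is where the speedup over the old per-index `_getitem` loop comes from.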