Fix iter_batches #5115

Merged 11 commits on Oct 14, 2022
23 changes: 11 additions & 12 deletions src/datasets/arrow_dataset.py
@@ -97,6 +97,7 @@
     embed_table_storage,
     list_table_cache_files,
     table_cast,
+    table_iter,
     table_visitor,
 )
 from .tasks import TaskTemplate
@@ -1924,7 +1925,7 @@ def __len__(self):
         """
         return self.num_rows
 
-    def _iter_batches(self, batch_size: int, decoded: bool = True):
+    def _iter_batches(self, batch_size: int, decoded: bool = True, drop_last_batch: bool = False):
"""Iterate through the batches of size `batch_size`.

If a formatting is set with :meth:`Dataset.set_format` rows will be returned with the
Expand All @@ -1935,16 +1936,15 @@ def _iter_batches(self, batch_size: int, decoded: bool = True):
             # Benchmark: https://gist.github.com/mariosasko/0248288a2e3a7556873969717c1fe52b (fast_iter_batch)
             format_kwargs = self._format_kwargs if self._format_kwargs is not None else {}
             formatter = get_formatter(self._format_type, features=self.features, decoded=decoded, **format_kwargs)
-            for batch in self.data.to_reader(max_chunksize=batch_size):
-                pa_subtable = pa.Table.from_batches([batch])
-                formatted_output = format_table(
+            for pa_subtable in table_iter(self.data, batch_size=batch_size, drop_last_batch=drop_last_batch):
+                formatted_batch = format_table(
                     pa_subtable,
                     range(pa_subtable.num_rows),
                     formatter=formatter,
                     format_columns=self._format_columns,
                     output_all_columns=self._output_all_columns,
                 )
-                yield formatted_output
+                yield formatted_batch
         else:
             for i in range(0, self.num_rows, batch_size):
                 yield self._getitem(
@@ -1964,12 +1964,11 @@ def _iter(self, decoded: bool = True):
             format_kwargs = self._format_kwargs if self._format_kwargs is not None else {}
             formatter = get_formatter(self._format_type, features=self.features, decoded=decoded, **format_kwargs)
             batch_size = config.ARROW_READER_BATCH_SIZE_IN_DATASET_ITER
-            for batch in self.data.to_reader(max_chunksize=batch_size):
-                for i in range(batch.num_rows):
-                    batch_ex = batch.slice(i, 1)
-                    pa_subtable = pa.Table.from_batches([batch_ex])
+            for pa_subtable in table_iter(self.data, batch_size=batch_size):
+                for i in range(pa_subtable.num_rows):
+                    pa_subtable_ex = pa_subtable.slice(i, 1)
                     formatted_output = format_table(
-                        pa_subtable,
+                        pa_subtable_ex,
                         0,
                         formatter=formatter,
                         format_columns=self._format_columns,
@@ -2936,8 +2935,8 @@ def init_buffer_and_writer():
                 len(input_dataset) if not drop_last_batch else len(input_dataset) // batch_size * batch_size
             )
             pbar_total = (num_rows // batch_size) + 1 if num_rows % batch_size else num_rows // batch_size
-            pbar_iterable = itertools.islice(
-                input_dataset._iter_batches(batch_size, decoded=False), pbar_total
+            pbar_iterable = input_dataset._iter_batches(
+                batch_size, decoded=False, drop_last_batch=drop_last_batch
             )
         pbar_unit = "ex" if not batched else "ba"
         pbar_desc = (desc + " " if desc is not None else "") + "#" + str(rank) if rank is not None else desc
36 changes: 36 additions & 0 deletions src/datasets/table.py
@@ -2148,3 +2148,39 @@ def _visit(array, feature):

     for name, feature in features.items():
         _visit(table[name], feature)


+def table_iter(pa_table: pa.Table, batch_size: int, drop_last_batch=False):
+    """Iterate over sub-tables of size `batch_size`.
+
+    Requires pyarrow>=8.0.0
+
+    Args:
+        pa_table (:obj:`pyarrow.Table`): PyArrow table to iterate over
+        batch_size (:obj:`int`): size of each sub-table to yield
+        drop_last_batch (:obj:`bool`, default `False`): Drop the last batch if it is smaller than `batch_size`
+    """
+    if config.PYARROW_VERSION.major < 8:
+        raise RuntimeError(f"pyarrow>=8.0.0 is needed to use table_iter but you have {config.PYARROW_VERSION}")
+    chunks_buffer = []
+    chunks_buffer_size = 0
+    for chunk in pa_table.to_reader(max_chunksize=batch_size):
+        if len(chunk) == 0:
+            continue
+        elif chunks_buffer_size + len(chunk) < batch_size:
+            chunks_buffer.append(chunk)
+            chunks_buffer_size += len(chunk)
+            continue
+        elif chunks_buffer_size + len(chunk) == batch_size:
+            chunks_buffer.append(chunk)
+            yield pa.Table.from_batches(chunks_buffer)
+            chunks_buffer = []
+            chunks_buffer_size = 0
+        else:
+            cropped_chunk_length = batch_size - chunks_buffer_size
+            chunks_buffer.append(chunk.slice(0, cropped_chunk_length))
+            yield pa.Table.from_batches(chunks_buffer)
+            chunks_buffer = [chunk.slice(cropped_chunk_length, len(chunk) - cropped_chunk_length)]
+            chunks_buffer_size = len(chunk) - cropped_chunk_length
Member (on lines +2166 to +2184):
Maybe we can remove the variable `chunks_buffer_size`:

Suggested change:

-    chunks_buffer_size = 0
     for chunk in pa_table.to_reader(max_chunksize=batch_size):
         if len(chunk) == 0:
             continue
-        elif chunks_buffer_size + len(chunk) < batch_size:
+        elif len(chunks_buffer) + len(chunk) < batch_size:
             chunks_buffer.append(chunk)
-            chunks_buffer_size += len(chunk)
             continue
-        elif chunks_buffer_size + len(chunk) == batch_size:
+        elif len(chunks_buffer) + len(chunk) == batch_size:
             chunks_buffer.append(chunk)
             yield pa.Table.from_batches(chunks_buffer)
             chunks_buffer = []
-            chunks_buffer_size = 0
         else:
-            cropped_chunk_length = batch_size - chunks_buffer_size
+            cropped_chunk_length = batch_size - len(chunks_buffer)
             chunks_buffer.append(chunk.slice(0, cropped_chunk_length))
             yield pa.Table.from_batches(chunks_buffer)
             chunks_buffer = [chunk.slice(cropped_chunk_length, len(chunk) - cropped_chunk_length)]
-            chunks_buffer_size = len(chunk) - cropped_chunk_length

lhoestq (Member Author), Oct 14, 2022:
`chunks_buffer_size` is the sum of the lengths of all the chunks in the buffer, not just the number of chunks in the buffer (which is what `len(chunks_buffer)` would measure).
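
To make the distinction concrete, a minimal sketch (the example chunks are made up for illustration):

```python
import pyarrow as pa

# A buffer holding two chunks of 3 rows each:
chunks_buffer = [
    pa.RecordBatch.from_pydict({"a": [1, 2, 3]}),
    pa.RecordBatch.from_pydict({"a": [4, 5, 6]}),
]

len(chunks_buffer)  # 2 -> number of chunks in the buffer
sum(len(chunk) for chunk in chunks_buffer)  # 6 -> what chunks_buffer_size tracks
```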

+    if not drop_last_batch and chunks_buffer:
+        yield pa.Table.from_batches(chunks_buffer)
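
For context, a minimal usage sketch of `table_iter` (assuming pyarrow>=8.0.0 and a version of `datasets` that ships this function), showing how rows from uneven input chunks are regrouped into fixed-size sub-tables:

```python
import pyarrow as pa

from datasets.table import table_iter

# A table backed by two uneven chunks: 4 rows + 3 rows.
pa_table = pa.concat_tables([
    pa.table({"foo": [0, 1, 2, 3]}),
    pa.table({"foo": [4, 5, 6]}),
])

for subtable in table_iter(pa_table, batch_size=3):
    print(len(subtable), subtable["foo"].to_pylist())
# 3 [0, 1, 2]
# 3 [3, 4, 5]
# 1 [6]

# With drop_last_batch=True, the trailing 1-row sub-table would be dropped.
```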
Member:
I'm just wondering if this function may have a performance impact, instead of just calling `for batch in self.data.to_reader(max_chunksize=batch_size)` as before.

If so, we should check how much impact, so that we do not lose the performance gain introduced by #5030.

lhoestq (Member Author), Oct 14, 2022:

The code is roughly the same as in #5030.

Also note that the worst case scenario for this implementation is when the dataset is made of chunks of length 1, but even in this case it is faster than calling `__getitem__` for each item:

from datasets import Dataset, concatenate_datasets

ds = concatenate_datasets([Dataset.from_dict({"a": [i]}) for i in range(100)])
%time list(ds._iter_batches(batch_size=10))
# <1ms
%time [ds[i:i+10] for i in range(0, len(ds), 10)]
# 1ms
%time list(ds)
# 3ms
%time [ds[i] for i in range(len(ds))]
# 5ms

It's even better for big datasets, since `__getitem__` is not O(1) there because of interpolation search, while getting the next item here is O(1).
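
For anyone reproducing this outside IPython, a self-contained sketch of the same measurement using `timeit` (numbers will vary by machine; `_iter_batches` is a private helper, used here only to mirror the benchmark above):

```python
import timeit

from datasets import Dataset, concatenate_datasets

# Worst case: a dataset made of 100 chunks of length 1.
ds = concatenate_datasets([Dataset.from_dict({"a": [i]}) for i in range(100)])

print(timeit.timeit(lambda: list(ds._iter_batches(batch_size=10)), number=100))
print(timeit.timeit(lambda: [ds[i : i + 10] for i in range(0, len(ds), 10)], number=100))
print(timeit.timeit(lambda: list(ds), number=100))
print(timeit.timeit(lambda: [ds[i] for i in range(len(ds))], number=100))
```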

28 changes: 28 additions & 0 deletions tests/test_table.py
@@ -6,6 +6,7 @@
 import pyarrow as pa
 import pytest
 
+import datasets
 from datasets import Sequence, Value
 from datasets.features.features import ClassLabel, Features, Image
 from datasets.table import (
@@ -24,6 +25,7 @@
     embed_table_storage,
     inject_arrow_table_documentation,
     table_cast,
+    table_iter,
 )
 
 from .utils import assert_arrow_memory_doesnt_increase, assert_arrow_memory_increases, slow
@@ -1118,3 +1120,29 @@ def test_embed_table_storage(image_file):
     embedded_images_table = embed_table_storage(table)
     assert embedded_images_table.to_pydict()["image"][0]["path"] is None
     assert isinstance(embedded_images_table.to_pydict()["image"][0]["bytes"], bytes)
+
+
+@pytest.mark.skipif(datasets.config.PYARROW_VERSION.major < 8, reason="only available on pyarrow>=8")
+@pytest.mark.parametrize(
+    "pa_table",
+    [
+        pa.table({"foo": range(10)}),
+        pa.concat_tables([pa.table({"foo": range(0, 5)}), pa.table({"foo": range(5, 10)})]),
+        pa.concat_tables([pa.table({"foo": [i]}) for i in range(10)]),
+    ],
+)
+@pytest.mark.parametrize("batch_size", [1, 2, 3, 9, 10, 11, 20])
+@pytest.mark.parametrize("drop_last_batch", [False, True])
+def test_table_iter(pa_table, batch_size, drop_last_batch):
+    num_rows = len(pa_table) if not drop_last_batch else len(pa_table) // batch_size * batch_size
+    num_batches = (num_rows // batch_size) + 1 if num_rows % batch_size else num_rows // batch_size
+    subtables = list(table_iter(pa_table, batch_size=batch_size, drop_last_batch=drop_last_batch))
+    assert len(subtables) == num_batches
+    if drop_last_batch:
+        assert all(len(subtable) == batch_size for subtable in subtables)
+    else:
+        assert all(len(subtable) == batch_size for subtable in subtables[:-1])
+        assert len(subtables[-1]) <= batch_size
+    if num_rows > 0:
+        reloaded = pa.concat_tables(subtables)
+        assert pa_table.slice(0, num_rows).to_pydict() == reloaded.to_pydict()