
Fix iter_batches #5115

Merged: 11 commits merged into main on Oct 14, 2022

Conversation

@lhoestq (Member) commented Oct 14, 2022

The pa.Table.to_reader() method, available in pyarrow>=8.0.0, may return chunks smaller than max_chunksize: it splits chunks that are too large but does not coalesce small ones. So iter_batches could return batches smaller than the batch_size specified by the user.
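
For illustration, a minimal sketch of the underlying pyarrow behavior (assuming pyarrow>=8.0.0; the table and column name are made up for this example):

import pyarrow as pa

# A table backed by ten 1-row record batches
table = pa.Table.from_batches([pa.RecordBatch.from_pydict({"a": [i]}) for i in range(10)])

# max_chunksize only caps the size of each returned batch, it does not
# coalesce small chunks, so every batch here has a single row
print([len(batch) for batch in table.to_reader(max_chunksize=5)])
# [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]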

As a result, batched map couldn't always use batches of the right size, e.g. this fails because map runs on only one batch of one element:

from datasets import Dataset, concatenate_datasets

ds = concatenate_datasets([Dataset.from_dict({"a": [i]}) for i in range(10)])

ds2 = ds.map(lambda _: {}, batched=True)
assert list(ds2) == list(ds)

This was introduced in #5030

Close #5111

This will require a patch release along with #5113

TODO:

  • fix tests
  • add more tests

@HuggingFaceDocBuilderDev commented Oct 14, 2022

The documentation is not available anymore as the PR was closed or merged.

lhoestq marked this pull request as ready for review October 14, 2022 13:39
@lhoestq (Member, Author) commented Oct 14, 2022

I also ran the code in #5111 and it works fine now :)

@lhoestq (Member, Author) commented Oct 14, 2022

This is ready for review :)

@albertvillanova (Member) left a comment

Thanks for the fix.

Just a few comments below.

src/datasets/table.py (review comment resolved)
Comment on lines +2162 to +2180
chunks_buffer_size = 0
for chunk in pa_table.to_reader(max_chunksize=batch_size):
    if len(chunk) == 0:
        continue
    elif chunks_buffer_size + len(chunk) < batch_size:
        chunks_buffer.append(chunk)
        chunks_buffer_size += len(chunk)
        continue
    elif chunks_buffer_size + len(chunk) == batch_size:
        chunks_buffer.append(chunk)
        yield pa.Table.from_batches(chunks_buffer)
        chunks_buffer = []
        chunks_buffer_size = 0
    else:
        cropped_chunk_length = batch_size - chunks_buffer_size
        chunks_buffer.append(chunk.slice(0, cropped_chunk_length))
        yield pa.Table.from_batches(chunks_buffer)
        chunks_buffer = [chunk.slice(cropped_chunk_length, len(chunk) - cropped_chunk_length)]
        chunks_buffer_size = len(chunk) - cropped_chunk_length
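
To see end to end what this buffering achieves, here is a self-contained, runnable sketch (the wrapper function _iter_batches_sketch is hypothetical; the real logic lives in src/datasets/table.py, including the drop_last_batch flush shown further below):

import pyarrow as pa

def _iter_batches_sketch(pa_table, batch_size, drop_last_batch=False):
    # to_reader guarantees each chunk has at most batch_size rows, so the
    # buffer never holds more than batch_size rows in total
    chunks_buffer = []
    chunks_buffer_size = 0
    for chunk in pa_table.to_reader(max_chunksize=batch_size):
        if len(chunk) == 0:
            continue
        elif chunks_buffer_size + len(chunk) < batch_size:
            # Not enough rows yet: keep buffering
            chunks_buffer.append(chunk)
            chunks_buffer_size += len(chunk)
        elif chunks_buffer_size + len(chunk) == batch_size:
            # Exactly full: emit a batch of batch_size rows
            chunks_buffer.append(chunk)
            yield pa.Table.from_batches(chunks_buffer)
            chunks_buffer = []
            chunks_buffer_size = 0
        else:
            # Overfull: split the chunk, emit a full batch, keep the rest
            cropped_chunk_length = batch_size - chunks_buffer_size
            chunks_buffer.append(chunk.slice(0, cropped_chunk_length))
            yield pa.Table.from_batches(chunks_buffer)
            chunks_buffer = [chunk.slice(cropped_chunk_length, len(chunk) - cropped_chunk_length)]
            chunks_buffer_size = len(chunk) - cropped_chunk_length
    if not drop_last_batch and chunks_buffer:
        yield pa.Table.from_batches(chunks_buffer)

table = pa.Table.from_batches([pa.RecordBatch.from_pydict({"a": [i]}) for i in range(10)])
print([len(b) for b in _iter_batches_sketch(table, batch_size=3)])
# [3, 3, 3, 1]: ten 1-row chunks regrouped into full batches plus a remainder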
@albertvillanova (Member)

Maybe we can remove the variable chunks_buffer_size:

Suggested change:

for chunk in pa_table.to_reader(max_chunksize=batch_size):
    if len(chunk) == 0:
        continue
    elif len(chunks_buffer) + len(chunk) < batch_size:
        chunks_buffer.append(chunk)
        continue
    elif len(chunks_buffer) + len(chunk) == batch_size:
        chunks_buffer.append(chunk)
        yield pa.Table.from_batches(chunks_buffer)
        chunks_buffer = []
    else:
        cropped_chunk_length = batch_size - len(chunks_buffer)
        chunks_buffer.append(chunk.slice(0, cropped_chunk_length))
        yield pa.Table.from_batches(chunks_buffer)
        chunks_buffer = [chunk.slice(cropped_chunk_length, len(chunk) - cropped_chunk_length)]

@lhoestq (Member, Author) commented Oct 14, 2022

chunks_buffer_size is the sum of the lengths of all the chunks in the buffer, not the number of chunks in the buffer (which is what len(chunks_buffer) gives).
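
For example (an illustrative snippet, not code from the PR):

import pyarrow as pa

chunks_buffer = [pa.RecordBatch.from_pydict({"a": [0, 1, 2]})]  # one 3-row chunk
print(len(chunks_buffer))                  # 1: the number of chunks
print(sum(len(c) for c in chunks_buffer))  # 3: what chunks_buffer_size tracks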

chunks_buffer = [chunk.slice(cropped_chunk_length, len(chunk) - cropped_chunk_length)]
chunks_buffer_size = len(chunk) - cropped_chunk_length
if not drop_last_batch and chunks_buffer:
    yield pa.Table.from_batches(chunks_buffer)
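
For context, this excerpt is the trailing flush: when drop_last_batch is False, any rows left in the buffer are yielded as a final, possibly smaller batch. Reusing the hypothetical sketch above:

print([len(b) for b in _iter_batches_sketch(table, batch_size=3, drop_last_batch=True)])
# [3, 3, 3]: the final 1-row remainder is dropped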
@albertvillanova (Member)

I'm just wondering whether this function may have a performance impact, compared with simply calling for batch in self.data.to_reader(max_chunksize=batch_size) as before.

If so, we should measure how much, so that we do not lose the performance gain introduced by #5030.

@lhoestq (Member, Author) commented Oct 14, 2022

The code is roughly the same as in #5030.

Also note that the worst-case scenario for this implementation is a dataset made of chunks of length 1, but even then it is faster than calling __getitem__ for each item:

from datasets import Dataset, concatenate_datasets

ds = concatenate_datasets([Dataset.from_dict({"a": [i]}) for i in range(100)])
%time list(ds._iter_batches(batch_size=10))
# <1ms
%time [ds[i:i+10] for i in range(0, len(ds), 10)]
# 1ms
%time list(ds)
# 3ms
%time [ds[i] for i in range(len(ds))]
# 5ms

It's even better for big datasets, since __getitem__ is not O(1): it uses an interpolation search over the table's chunks to locate the requested rows. Here, getting the next item is O(1).
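
For intuition, a simplified sketch of why random access costs more on a chunked table (binary search stands in here for the interpolation search used in datasets, and the names are made up for this example; sequential iteration just walks the chunks in order, one step at a time):

import bisect

def locate_row(chunk_offsets, i):
    # chunk_offsets[k] is the number of rows before chunk k; finding the
    # chunk that holds row i requires a search over these offsets
    k = bisect.bisect_right(chunk_offsets, i) - 1
    return k, i - chunk_offsets[k]

offsets = [0, 1, 2, 3]  # four 1-row chunks
print(locate_row(offsets, 2))
# (2, 0): row 2 lives in chunk 2 at local position 0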

Co-authored-by: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com>
@mariosasko (Contributor) left a comment

Thanks, LGTM!

lhoestq merged commit eadc79a into main Oct 14, 2022
lhoestq deleted the fix-iter_batches branch October 14, 2022 14:59
Successfully merging this pull request may close these issues:

  • map and filter not working properly in multiprocessing with the new release 2.6.0