Skip to content

Commit

Permalink
Fast dataset iter
Browse files Browse the repository at this point in the history
  • Loading branch information
mariosasko committed Sep 27, 2022
1 parent 6a91e94 commit 3025024
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 12 deletions.
66 changes: 54 additions & 12 deletions src/datasets/arrow_dataset.py
Expand Up @@ -1854,17 +1854,61 @@ def __len__(self):
"""
return self.num_rows

def _iter_batches(self, batch_size, decoded: bool = True):
"""Iterate through the batches of size `batch_size`.
If a formatting is set with :meth:`Dataset.set_format` rows will be returned with the
selected format.
"""
if self._indices is None and config.PYARROW_VERSION.major >= 8:
# Fast iteration
format_kwargs = self._format_kwargs if self._format_kwargs is not None else {}
formatter = get_formatter(self._format_type, features=self.features, decoded=decoded, **format_kwargs)
schema = self.data._schema
for i, batch in enumerate(self.data.to_reader(max_chunksize=batch_size)):
pa_subtable = pa.Table.from_batches([batch], schema=schema)
formatted_output = format_table(
pa_subtable,
slice(i, i + batch_size),
formatter=formatter,
format_columns=self._format_columns,
output_all_columns=self._output_all_columns,
)
yield formatted_output
else:
for i in range(0, self.num_rows, batch_size):
yield self._getitem(
slice(i, i + batch_size),
decoded=decoded,
)

def _iter(self, decoded: bool = True):
"""Iterate through the examples.
If a formatting is set with :meth:`Dataset.set_format` rows will be returned with the
selected format.
"""
for index in range(self.num_rows):
yield self._getitem(
index,
decoded=decoded,
)
if self._indices is None and config.PYARROW_VERSION.major >= 8:
# Fast iteration
format_kwargs = self._format_kwargs if self._format_kwargs is not None else {}
formatter = get_formatter(self._format_type, features=self.features, decoded=decoded, **format_kwargs)
schema = self.data._schema
for i, batch in enumerate(self.data.to_reader(max_chunksize=1)):
pa_subtable = pa.Table.from_batches([batch], schema=schema)
formatted_output = format_table(
pa_subtable,
i,
formatter=formatter,
format_columns=self._format_columns,
output_all_columns=self._output_all_columns,
)
yield formatted_output
else:
for i in range(self.num_rows):
yield self._getitem(
i,
decoded=decoded,
)

def __iter__(self):
"""Iterate through the examples.
Expand Down Expand Up @@ -2805,14 +2849,16 @@ def init_buffer_and_writer():

# Loop over single examples or batches and write to buffer/file if examples are to be updated
if not batched:
pbar_iterable = input_dataset._iter(decoded=False)
pbar_total = len(input_dataset)
pbar_iterable = input_dataset._iter(decoded=False)
else:
num_rows = (
len(input_dataset) if not drop_last_batch else len(input_dataset) // batch_size * batch_size
)
pbar_iterable = range(0, num_rows, batch_size)
pbar_total = (num_rows // batch_size) + 1 if num_rows % batch_size else num_rows // batch_size
pbar_iterable = itertools.islice(
input_dataset._iter_batches(batch_size, decoded=False), pbar_total
)
pbar_unit = "ex" if not batched else "ba"
pbar_desc = (desc + " " if desc is not None else "") + "#" + str(rank) if rank is not None else desc
pbar = logging.tqdm(
Expand All @@ -2835,11 +2881,7 @@ def init_buffer_and_writer():
else:
writer.write(example)
else:
for i in pbar:
batch = input_dataset._getitem(
slice(i, i + batch_size),
decoded=False,
)
for i, batch in enumerate(pbar):
indices = list(
range(*(slice(i, i + batch_size).indices(input_dataset.num_rows)))
) # Something simpler?
Expand Down
24 changes: 24 additions & 0 deletions src/datasets/table.py
Expand Up @@ -330,6 +330,30 @@ def to_pandas(self, *args, **kwargs):
def to_string(self, *args, **kwargs):
return self.table.to_string(*args, **kwargs)

def to_reader(self, *args, **kwargs):
"""
Convert the Table to a RecordBatchReader.
Note that this method is zero-copy, it merely exposes the same data under a different API.
Args:
max_chunksize (:obj:`int`, defaults to :obj:`None`)
Maximum size for RecordBatch chunks. Individual chunks may be smaller depending
on the chunk layout of individual columns.
Returns:
:obj:`pyarrow.RecordBatchReader`
<Tip warning={true}>
pyarrow >= 8.0.0 needs to be installed to use this method.
</Tip>
"""
if config.PYARROW_VERSION.major < 8:
raise NotImplementedError("`pyarrow>=8.0.0` is required to use this method")
return self.table.to_reader(*args, **kwargs)

def field(self, *args, **kwargs):
"""
Select a schema field by its column name or numeric index.
Expand Down

1 comment on commit 3025024

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Show benchmarks

PyArrow==6.0.0

Show updated benchmarks!

Benchmark: benchmark_array_xd.json

metric read_batch_formatted_as_numpy after write_array2d read_batch_formatted_as_numpy after write_flattened_sequence read_batch_formatted_as_numpy after write_nested_sequence read_batch_unformated after write_array2d read_batch_unformated after write_flattened_sequence read_batch_unformated after write_nested_sequence read_col_formatted_as_numpy after write_array2d read_col_formatted_as_numpy after write_flattened_sequence read_col_formatted_as_numpy after write_nested_sequence read_col_unformated after write_array2d read_col_unformated after write_flattened_sequence read_col_unformated after write_nested_sequence read_formatted_as_numpy after write_array2d read_formatted_as_numpy after write_flattened_sequence read_formatted_as_numpy after write_nested_sequence read_unformated after write_array2d read_unformated after write_flattened_sequence read_unformated after write_nested_sequence write_array2d write_flattened_sequence write_nested_sequence
new / old (diff) 0.010984 / 0.011353 (-0.000369) 0.005127 / 0.011008 (-0.005882) 0.052404 / 0.038508 (0.013895) 0.037183 / 0.023109 (0.014074) 0.391363 / 0.275898 (0.115465) 0.464045 / 0.323480 (0.140565) 0.006832 / 0.007986 (-0.001154) 0.005287 / 0.004328 (0.000959) 0.010250 / 0.004250 (0.006000) 0.051655 / 0.037052 (0.014603) 0.395595 / 0.258489 (0.137106) 0.457007 / 0.293841 (0.163166) 0.050559 / 0.128546 (-0.077987) 0.015179 / 0.075646 (-0.060468) 0.411768 / 0.419271 (-0.007503) 0.077307 / 0.043533 (0.033774) 0.385026 / 0.255139 (0.129887) 0.420398 / 0.283200 (0.137198) 0.112617 / 0.141683 (-0.029066) 1.849130 / 1.452155 (0.396976) 1.839390 / 1.492716 (0.346673)

Benchmark: benchmark_getitem_100B.json

metric get_batch_of_1024_random_rows get_batch_of_1024_rows get_first_row get_last_row
new / old (diff) 0.223652 / 0.018006 (0.205646) 0.503631 / 0.000490 (0.503141) 0.008021 / 0.000200 (0.007821) 0.000446 / 0.000054 (0.000391)

Benchmark: benchmark_indices_mapping.json

metric select shard shuffle sort train_test_split
new / old (diff) 0.024072 / 0.037411 (-0.013340) 0.111402 / 0.014526 (0.096877) 0.126707 / 0.176557 (-0.049849) 0.176114 / 0.737135 (-0.561021) 0.127278 / 0.296338 (-0.169061)

Benchmark: benchmark_iterating.json

metric read 5000 read 50000 read_batch 50000 10 read_batch 50000 100 read_batch 50000 1000 read_formatted numpy 5000 read_formatted pandas 5000 read_formatted tensorflow 5000 read_formatted torch 5000 read_formatted_batch numpy 5000 10 read_formatted_batch numpy 5000 1000 shuffled read 5000 shuffled read 50000 shuffled read_batch 50000 10 shuffled read_batch 50000 100 shuffled read_batch 50000 1000 shuffled read_formatted numpy 5000 shuffled read_formatted_batch numpy 5000 10 shuffled read_formatted_batch numpy 5000 1000
new / old (diff) 0.607058 / 0.215209 (0.391849) 6.145336 / 2.077655 (4.067681) 2.447749 / 1.504120 (0.943629) 2.049444 / 1.541195 (0.508249) 2.044094 / 1.468490 (0.575604) 0.759570 / 4.584777 (-3.825207) 5.513924 / 3.745712 (1.768211) 5.037092 / 5.269862 (-0.232770) 2.661084 / 4.565676 (-1.904593) 0.090152 / 0.424275 (-0.334123) 0.013326 / 0.007607 (0.005719) 0.822035 / 0.226044 (0.595990) 8.001869 / 2.268929 (5.732941) 2.970272 / 55.444624 (-52.474352) 2.386760 / 6.876477 (-4.489716) 2.327264 / 2.142072 (0.185191) 0.993299 / 4.805227 (-3.811928) 0.197777 / 6.500664 (-6.302887) 0.072563 / 0.075469 (-0.002906)

Benchmark: benchmark_map_filter.json

metric filter map fast-tokenizer batched map identity map identity batched map no-op batched map no-op batched numpy map no-op batched pandas map no-op batched pytorch map no-op batched tensorflow
new / old (diff) 1.827476 / 1.841788 (-0.014312) 14.764338 / 8.074308 (6.690029) 42.868986 / 10.191392 (32.677594) 1.101329 / 0.680424 (0.420905) 0.669927 / 0.534201 (0.135726) 0.462462 / 0.579283 (-0.116821) 0.595085 / 0.434364 (0.160721) 0.329629 / 0.540337 (-0.210709) 0.351292 / 1.386936 (-1.035644)
PyArrow==latest
Show updated benchmarks!

Benchmark: benchmark_array_xd.json

metric read_batch_formatted_as_numpy after write_array2d read_batch_formatted_as_numpy after write_flattened_sequence read_batch_formatted_as_numpy after write_nested_sequence read_batch_unformated after write_array2d read_batch_unformated after write_flattened_sequence read_batch_unformated after write_nested_sequence read_col_formatted_as_numpy after write_array2d read_col_formatted_as_numpy after write_flattened_sequence read_col_formatted_as_numpy after write_nested_sequence read_col_unformated after write_array2d read_col_unformated after write_flattened_sequence read_col_unformated after write_nested_sequence read_formatted_as_numpy after write_array2d read_formatted_as_numpy after write_flattened_sequence read_formatted_as_numpy after write_nested_sequence read_unformated after write_array2d read_unformated after write_flattened_sequence read_unformated after write_nested_sequence write_array2d write_flattened_sequence write_nested_sequence
new / old (diff) 0.006843 / 0.011353 (-0.004510) 0.004269 / 0.011008 (-0.006739) 0.030161 / 0.038508 (-0.008347) 0.031377 / 0.023109 (0.008268) 0.382369 / 0.275898 (0.106471) 0.492130 / 0.323480 (0.168650) 0.004057 / 0.007986 (-0.003928) 0.005204 / 0.004328 (0.000875) 0.005876 / 0.004250 (0.001626) 0.040692 / 0.037052 (0.003639) 0.404539 / 0.258489 (0.146050) 0.482415 / 0.293841 (0.188574) 0.041423 / 0.128546 (-0.087124) 0.012616 / 0.075646 (-0.063031) 0.285252 / 0.419271 (-0.134020) 0.062086 / 0.043533 (0.018553) 0.382572 / 0.255139 (0.127433) 0.413574 / 0.283200 (0.130374) 0.104392 / 0.141683 (-0.037291) 1.573272 / 1.452155 (0.121117) 1.641367 / 1.492716 (0.148651)

Benchmark: benchmark_getitem_100B.json

metric get_batch_of_1024_random_rows get_batch_of_1024_rows get_first_row get_last_row
new / old (diff) 0.209099 / 0.018006 (0.191093) 0.484051 / 0.000490 (0.483561) 0.001216 / 0.000200 (0.001016) 0.000102 / 0.000054 (0.000048)

Benchmark: benchmark_indices_mapping.json

metric select shard shuffle sort train_test_split
new / old (diff) 0.023788 / 0.037411 (-0.013623) 0.098800 / 0.014526 (0.084274) 0.143555 / 0.176557 (-0.033001) 0.155323 / 0.737135 (-0.581813) 0.123990 / 0.296338 (-0.172349)

Benchmark: benchmark_iterating.json

metric read 5000 read 50000 read_batch 50000 10 read_batch 50000 100 read_batch 50000 1000 read_formatted numpy 5000 read_formatted pandas 5000 read_formatted tensorflow 5000 read_formatted torch 5000 read_formatted_batch numpy 5000 10 read_formatted_batch numpy 5000 1000 shuffled read 5000 shuffled read 50000 shuffled read_batch 50000 10 shuffled read_batch 50000 100 shuffled read_batch 50000 1000 shuffled read_formatted numpy 5000 shuffled read_formatted_batch numpy 5000 10 shuffled read_formatted_batch numpy 5000 1000
new / old (diff) 0.782304 / 0.215209 (0.567095) 6.730744 / 2.077655 (4.653090) 2.821089 / 1.504120 (1.316969) 2.425980 / 1.541195 (0.884785) 2.482846 / 1.468490 (1.014355) 0.766043 / 4.584777 (-3.818734) 5.444556 / 3.745712 (1.698844) 5.298215 / 5.269862 (0.028354) 2.816684 / 4.565676 (-1.748993) 0.087080 / 0.424275 (-0.337195) 0.013573 / 0.007607 (0.005966) 0.810292 / 0.226044 (0.584248) 8.109576 / 2.268929 (5.840647) 3.502724 / 55.444624 (-51.941901) 2.823952 / 6.876477 (-4.052525) 2.842505 / 2.142072 (0.700433) 0.993719 / 4.805227 (-3.811509) 0.219012 / 6.500664 (-6.281652) 0.076390 / 0.075469 (0.000921)

Benchmark: benchmark_map_filter.json

metric filter map fast-tokenizer batched map identity map identity batched map no-op batched map no-op batched numpy map no-op batched pandas map no-op batched pytorch map no-op batched tensorflow
new / old (diff) 1.727401 / 1.841788 (-0.114387) 14.821053 / 8.074308 (6.746744) 22.271220 / 10.191392 (12.079828) 1.180765 / 0.680424 (0.500341) 0.670604 / 0.534201 (0.136403) 0.414044 / 0.579283 (-0.165239) 0.543201 / 0.434364 (0.108837) 0.290886 / 0.540337 (-0.249452) 0.295952 / 1.386936 (-1.090984)

CML watermark

Please sign in to comment.