fix: avoid casting tuples after Dataset.map (#4993)
* fix: avoid casting tuples after Dataset.map

* fix: fix test_cast_to_python_objects_tuple test
szmoro committed Sep 20, 2022
1 parent 1b4c3cb commit 8ba0522
Showing 2 changed files with 3 additions and 3 deletions.
4 changes: 2 additions & 2 deletions src/datasets/features/features.py
@@ -379,12 +379,12 @@ def _cast_to_python_objects(obj: Any, only_1d_for_numpy: bool, optimize_list_cas
                     for elmt in obj
                 ], True
             else:
-                if isinstance(obj, list):
+                if isinstance(obj, (list, tuple)):
                     return obj, False
                 else:
                     return list(obj), True
         else:
-            return obj if isinstance(obj, list) else [], isinstance(obj, tuple)
+            return obj, False
     else:
         return obj, False
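
The effect of the two changed branches can be illustrated with a minimal, standalone sketch. It is not the library's function: `tail_before` and `tail_after` are hypothetical helpers that reproduce only the branches touched by this hunk, assuming `obj` is a list or tuple whose elements need no casting themselves.

```python
from typing import Any, Tuple


def tail_before(obj: Any) -> Tuple[Any, bool]:
    """Changed branches as they were before this commit (hypothetical helper)."""
    if len(obj) > 0:
        if isinstance(obj, list):
            return obj, False
        # Tuples fell through to this line and were copied into a list.
        return list(obj), True
    # An empty tuple became an empty list and was reported as changed.
    return (obj if isinstance(obj, list) else []), isinstance(obj, tuple)


def tail_after(obj: Any) -> Tuple[Any, bool]:
    """Changed branches as they are after this commit (hypothetical helper)."""
    if len(obj) > 0:
        if isinstance(obj, (list, tuple)):
            # Lists and tuples are both returned untouched.
            return obj, False
        return list(obj), True
    return obj, False


for sample in [(1, 2, 3), (), [1, 2, 3]]:
    print(repr(sample), "before:", tail_before(sample), "after:", tail_after(sample))
```

In this sketch, lists behave the same before and after; only tuples (empty or not) stop being converted and stop being flagged as changed.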

2 changes: 1 addition & 1 deletion tests/features/test_features.py
@@ -463,7 +463,7 @@ def test_cast_to_python_objects_list(self):

     def test_cast_to_python_objects_tuple(self):
         obj = {"col_1": [{"vec": (1, 2, 3), "txt": "foo"}] * 3, "col_2": [(1, 2), (3, 4), (5, 6)]}
-        expected_obj = {"col_1": [{"vec": [1, 2, 3], "txt": "foo"}] * 3, "col_2": [[1, 2], [3, 4], [5, 6]]}
+        expected_obj = {"col_1": [{"vec": (1, 2, 3), "txt": "foo"}] * 3, "col_2": [(1, 2), (3, 4), (5, 6)]}
         casted_obj = cast_to_python_objects(obj)
         self.assertDictEqual(casted_obj, expected_obj)


1 comment on commit 8ba0522

@github-actions

Show benchmarks

PyArrow==6.0.0

Show updated benchmarks!

Benchmark: benchmark_array_xd.json

| metric | new / old (diff) |
|---|---|
| read_batch_formatted_as_numpy after write_array2d | 0.009444 / 0.011353 (-0.001909) |
| read_batch_formatted_as_numpy after write_flattened_sequence | 0.004516 / 0.011008 (-0.006492) |
| read_batch_formatted_as_numpy after write_nested_sequence | 0.035511 / 0.038508 (-0.002997) |
| read_batch_unformated after write_array2d | 0.042071 / 0.023109 (0.018962) |
| read_batch_unformated after write_flattened_sequence | 0.351684 / 0.275898 (0.075786) |
| read_batch_unformated after write_nested_sequence | 0.421962 / 0.323480 (0.098482) |
| read_col_formatted_as_numpy after write_array2d | 0.007046 / 0.007986 (-0.000940) |
| read_col_formatted_as_numpy after write_flattened_sequence | 0.004000 / 0.004328 (-0.000328) |
| read_col_formatted_as_numpy after write_nested_sequence | 0.008191 / 0.004250 (0.003940) |
| read_col_unformated after write_array2d | 0.060554 / 0.037052 (0.023502) |
| read_col_unformated after write_flattened_sequence | 0.364435 / 0.258489 (0.105946) |
| read_col_unformated after write_nested_sequence | 0.412708 / 0.293841 (0.118867) |
| read_formatted_as_numpy after write_array2d | 0.041938 / 0.128546 (-0.086608) |
| read_formatted_as_numpy after write_flattened_sequence | 0.011183 / 0.075646 (-0.064463) |
| read_formatted_as_numpy after write_nested_sequence | 0.309273 / 0.419271 (-0.109998) |
| read_unformated after write_array2d | 0.061895 / 0.043533 (0.018362) |
| read_unformated after write_flattened_sequence | 0.351378 / 0.255139 (0.096239) |
| read_unformated after write_nested_sequence | 0.382368 / 0.283200 (0.099168) |
| write_array2d | 0.127121 / 0.141683 (-0.014562) |
| write_flattened_sequence | 1.783684 / 1.452155 (0.331529) |
| write_nested_sequence | 1.900705 / 1.492716 (0.407988) |

Benchmark: benchmark_getitem_100B.json

| metric | new / old (diff) |
|---|---|
| get_batch_of_1024_random_rows | 0.245255 / 0.018006 (0.227249) |
| get_batch_of_1024_rows | 0.537100 / 0.000490 (0.536610) |
| get_first_row | 0.002094 / 0.000200 (0.001894) |
| get_last_row | 0.000151 / 0.000054 (0.000097) |

Benchmark: benchmark_indices_mapping.json

| metric | new / old (diff) |
|---|---|
| select | 0.028217 / 0.037411 (-0.009194) |
| shard | 0.122076 / 0.014526 (0.107550) |
| shuffle | 0.134081 / 0.176557 (-0.042476) |
| sort | 0.181842 / 0.737135 (-0.555293) |
| train_test_split | 0.140482 / 0.296338 (-0.155857) |

Benchmark: benchmark_iterating.json

| metric | new / old (diff) |
|---|---|
| read 5000 | 0.468833 / 0.215209 (0.253624) |
| read 50000 | 4.701532 / 2.077655 (2.623877) |
| read_batch 50000 10 | 2.096427 / 1.504120 (0.592307) |
| read_batch 50000 100 | 1.881125 / 1.541195 (0.339930) |
| read_batch 50000 1000 | 1.943935 / 1.468490 (0.475445) |
| read_formatted numpy 5000 | 0.504459 / 4.584777 (-4.080318) |
| read_formatted pandas 5000 | 4.934575 / 3.745712 (1.188863) |
| read_formatted tensorflow 5000 | 4.320272 / 5.269862 (-0.949590) |
| read_formatted torch 5000 | 2.214783 / 4.565676 (-2.350893) |
| read_formatted_batch numpy 5000 10 | 0.060946 / 0.424275 (-0.363329) |
| read_formatted_batch numpy 5000 1000 | 0.013071 / 0.007607 (0.005464) |
| shuffled read 5000 | 0.594164 / 0.226044 (0.368120) |
| shuffled read 50000 | 5.934088 / 2.268929 (3.665160) |
| shuffled read_batch 50000 10 | 2.633149 / 55.444624 (-52.811476) |
| shuffled read_batch 50000 100 | 2.269314 / 6.876477 (-4.607163) |
| shuffled read_batch 50000 1000 | 2.419241 / 2.142072 (0.277169) |
| shuffled read_formatted numpy 5000 | 0.636695 / 4.805227 (-4.168533) |
| shuffled read_formatted_batch numpy 5000 10 | 0.141855 / 6.500664 (-6.358809) |
| shuffled read_formatted_batch numpy 5000 1000 | 0.072495 / 0.075469 (-0.002974) |

Benchmark: benchmark_map_filter.json

| metric | new / old (diff) |
|---|---|
| filter | 1.729927 / 1.841788 (-0.111861) |
| map fast-tokenizer batched | 16.263581 / 8.074308 (8.189273) |
| map identity | 29.537442 / 10.191392 (19.346050) |
| map identity batched | 1.034746 / 0.680424 (0.354322) |
| map no-op batched | 0.662094 / 0.534201 (0.127893) |
| map no-op batched numpy | 0.457363 / 0.579283 (-0.121920) |
| map no-op batched pandas | 0.498503 / 0.434364 (0.064140) |
| map no-op batched pytorch | 0.328407 / 0.540337 (-0.211930) |
| map no-op batched tensorflow | 0.322905 / 1.386936 (-1.064031) |

PyArrow==latest
Show updated benchmarks!

Benchmark: benchmark_array_xd.json

| metric | new / old (diff) |
|---|---|
| read_batch_formatted_as_numpy after write_array2d | 0.007307 / 0.011353 (-0.004046) |
| read_batch_formatted_as_numpy after write_flattened_sequence | 0.004517 / 0.011008 (-0.006491) |
| read_batch_formatted_as_numpy after write_nested_sequence | 0.033247 / 0.038508 (-0.005261) |
| read_batch_unformated after write_array2d | 0.041492 / 0.023109 (0.018382) |
| read_batch_unformated after write_flattened_sequence | 0.406871 / 0.275898 (0.130973) |
| read_batch_unformated after write_nested_sequence | 0.492768 / 0.323480 (0.169288) |
| read_col_formatted_as_numpy after write_array2d | 0.004621 / 0.007986 (-0.003364) |
| read_col_formatted_as_numpy after write_flattened_sequence | 0.005572 / 0.004328 (0.001243) |
| read_col_formatted_as_numpy after write_nested_sequence | 0.005694 / 0.004250 (0.001444) |
| read_col_unformated after write_array2d | 0.053022 / 0.037052 (0.015969) |
| read_col_unformated after write_flattened_sequence | 0.414620 / 0.258489 (0.156131) |
| read_col_unformated after write_nested_sequence | 0.464400 / 0.293841 (0.170559) |
| read_formatted_as_numpy after write_array2d | 0.035553 / 0.128546 (-0.092993) |
| read_formatted_as_numpy after write_flattened_sequence | 0.011297 / 0.075646 (-0.064350) |
| read_formatted_as_numpy after write_nested_sequence | 0.306996 / 0.419271 (-0.112276) |
| read_unformated after write_array2d | 0.081770 / 0.043533 (0.038237) |
| read_unformated after write_flattened_sequence | 0.405583 / 0.255139 (0.150444) |
| read_unformated after write_nested_sequence | 0.437417 / 0.283200 (0.154217) |
| write_array2d | 0.130263 / 0.141683 (-0.011420) |
| write_flattened_sequence | 1.742250 / 1.452155 (0.290095) |
| write_nested_sequence | 1.793897 / 1.492716 (0.301181) |

Benchmark: benchmark_getitem_100B.json

| metric | new / old (diff) |
|---|---|
| get_batch_of_1024_random_rows | 0.237790 / 0.018006 (0.219784) |
| get_batch_of_1024_rows | 0.496673 / 0.000490 (0.496183) |
| get_first_row | 0.005013 / 0.000200 (0.004813) |
| get_last_row | 0.000161 / 0.000054 (0.000107) |

Benchmark: benchmark_indices_mapping.json

| metric | new / old (diff) |
|---|---|
| select | 0.027287 / 0.037411 (-0.010124) |
| shard | 0.124941 / 0.014526 (0.110415) |
| shuffle | 0.135558 / 0.176557 (-0.040999) |
| sort | 0.197394 / 0.737135 (-0.539742) |
| train_test_split | 0.140681 / 0.296338 (-0.155658) |

Benchmark: benchmark_iterating.json

| metric | new / old (diff) |
|---|---|
| read 5000 | 0.501387 / 0.215209 (0.286178) |
| read 50000 | 5.011001 / 2.077655 (2.933346) |
| read_batch 50000 10 | 2.441348 / 1.504120 (0.937228) |
| read_batch 50000 100 | 2.203006 / 1.541195 (0.661812) |
| read_batch 50000 1000 | 2.267925 / 1.468490 (0.799435) |
| read_formatted numpy 5000 | 0.506786 / 4.584777 (-4.077991) |
| read_formatted pandas 5000 | 4.781957 / 3.745712 (1.036245) |
| read_formatted tensorflow 5000 | 4.656309 / 5.269862 (-0.613553) |
| read_formatted torch 5000 | 2.267645 / 4.565676 (-2.298031) |
| read_formatted_batch numpy 5000 10 | 0.062347 / 0.424275 (-0.361928) |
| read_formatted_batch numpy 5000 1000 | 0.013374 / 0.007607 (0.005767) |
| shuffled read 5000 | 0.616848 / 0.226044 (0.390804) |
| shuffled read 50000 | 6.170259 / 2.268929 (3.901330) |
| shuffled read_batch 50000 10 | 2.991840 / 55.444624 (-52.452784) |
| shuffled read_batch 50000 100 | 2.599533 / 6.876477 (-4.276944) |
| shuffled read_batch 50000 1000 | 2.720901 / 2.142072 (0.578829) |
| shuffled read_formatted numpy 5000 | 0.632663 / 4.805227 (-4.172564) |
| shuffled read_formatted_batch numpy 5000 10 | 0.141917 / 6.500664 (-6.358747) |
| shuffled read_formatted_batch numpy 5000 1000 | 0.073840 / 0.075469 (-0.001629) |

Benchmark: benchmark_map_filter.json

| metric | new / old (diff) |
|---|---|
| filter | 1.922299 / 1.841788 (0.080511) |
| map fast-tokenizer batched | 16.626805 / 8.074308 (8.552497) |
| map identity | 29.707011 / 10.191392 (19.515619) |
| map identity batched | 1.176637 / 0.680424 (0.496213) |
| map no-op batched | 0.770270 / 0.534201 (0.236070) |
| map no-op batched numpy | 0.491311 / 0.579283 (-0.087972) |
| map no-op batched pandas | 0.546464 / 0.434364 (0.112100) |
| map no-op batched pytorch | 0.318545 / 0.540337 (-0.221793) |
| map no-op batched tensorflow | 0.325528 / 1.386936 (-1.061408) |
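
Each benchmark entry above is reported as `new / old (diff)`, and the figures are consistent with diff = new - old. A minimal sketch for reading such entries back and checking that relation follows; `parse_cells` and `is_consistent` are hypothetical helper names, not part of the benchmark tooling. The tolerance is an assumption: some printed diffs are off by one in the last digit (for example 0.004621 / 0.007986 (-0.003364)), which suggests they were computed before rounding.

```python
import re
from typing import List, Tuple

# One entry looks like "0.245255 / 0.018006 (0.227249)".
CELL = re.compile(r"(-?\d+\.\d+) / (-?\d+\.\d+) \((-?\d+\.\d+)\)")


def parse_cells(text: str) -> List[Tuple[float, float, float]]:
    """Return (new, old, diff) triples for every entry found in the text."""
    return [
        (float(m.group(1)), float(m.group(2)), float(m.group(3)))
        for m in CELL.finditer(text)
    ]


def is_consistent(new: float, old: float, diff: float, tol: float = 2e-6) -> bool:
    """Check diff == new - old, allowing for rounding in the printed values."""
    return abs((new - old) - diff) <= tol


# Values copied from the benchmark_getitem_100B.json results for PyArrow==6.0.0.
row = (
    "0.245255 / 0.018006 (0.227249) 0.537100 / 0.000490 (0.536610) "
    "0.002094 / 0.000200 (0.001894) 0.000151 / 0.000054 (0.000097)"
)
cells = parse_cells(row)
for new, old, diff in cells:
    print(f"new={new} old={old} diff={diff} consistent={is_consistent(new, old, diff)}")
```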
