From d60f5ff896b32bdc5cfdd6f91bfb6d1926e09f7a Mon Sep 17 00:00:00 2001
From: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com>
Date: Fri, 14 Oct 2022 14:11:43 +0200
Subject: [PATCH] Fix filter indices when batched (#5113)

* Test filter indices

* Fix filter indices when batched

* Rename test
---
 src/datasets/arrow_dataset.py | 2 +-
 tests/test_arrow_dataset.py   | 6 ++++++
 2 files changed, 7 insertions(+), 1 deletion(-)

diff --git a/src/datasets/arrow_dataset.py b/src/datasets/arrow_dataset.py
index d3db912357b..9bb043b9195 100644
--- a/src/datasets/arrow_dataset.py
+++ b/src/datasets/arrow_dataset.py
@@ -2961,7 +2961,7 @@ def init_buffer_and_writer():
                         else:
                             writer.write(example)
                 else:
-                    for i, batch in enumerate(pbar):
+                    for i, batch in zip(range(0, num_rows, batch_size), pbar):
                         indices = list(
                             range(*(slice(i, i + batch_size).indices(input_dataset.num_rows)))
                         )  # Something simpler?
diff --git a/tests/test_arrow_dataset.py b/tests/test_arrow_dataset.py
index 590e7afd191..2b585dd2f89 100644
--- a/tests/test_arrow_dataset.py
+++ b/tests/test_arrow_dataset.py
@@ -3081,6 +3081,12 @@ def test_dataset_add_item_introduce_feature_type():
     assert dataset[:] == {"col_1": [None, None, None, "a"]}
 
 
+def test_dataset_filter_batched_indices():
+    ds = Dataset.from_dict({"num": [0, 1, 2, 3]})
+    ds = ds.filter(lambda num: num % 2 == 0, input_columns="num", batch_size=2)
+    assert all(item["num"] % 2 == 0 for item in ds)
+
+
 @pytest.mark.parametrize("in_memory", [False, True])
 def test_dataset_from_file(in_memory, dataset, arrow_file):
     filename = arrow_file
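
Note (illustration only, not part of the patch above): the batched filter loop derives the absolute row indices covered by each batch from i via slice(i, i + batch_size). With enumerate(pbar), i is the batch counter (0, 1, 2, ...) rather than a row offset, so every window after the first overlaps the start of the table and the wrong rows are kept. The sketch below replays that arithmetic with made-up values for num_rows, batch_size, and a batches stand-in; it is not the library's internal code.

num_rows = 10   # pretend dataset size
batch_size = 4
batches = ["b0", "b1", "b2"]  # stand-ins for the batches yielded by the progress bar

# Old behaviour: enumerate() makes i the batch counter, so the row windows
# computed from slice(i, i + batch_size) overlap near the start of the table.
old = [
    list(range(*slice(i, i + batch_size).indices(num_rows)))
    for i, _batch in enumerate(batches)
]
assert old == [[0, 1, 2, 3], [1, 2, 3, 4], [2, 3, 4, 5]]

# Fixed behaviour: zipping with range(0, num_rows, batch_size) makes i the
# absolute starting row of each batch, so the windows tile the whole dataset.
new = [
    list(range(*slice(i, i + batch_size).indices(num_rows)))
    for i, _batch in zip(range(0, num_rows, batch_size), batches)
]
assert new == [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]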