From c13453d0af6633c1a9850cde0914a1e50a8c2c0e Mon Sep 17 00:00:00 2001 From: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com> Date: Fri, 14 Oct 2022 13:24:19 +0200 Subject: [PATCH 1/3] Test filter indices --- tests/test_arrow_dataset.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tests/test_arrow_dataset.py b/tests/test_arrow_dataset.py index 590e7afd191..d37ab76c360 100644 --- a/tests/test_arrow_dataset.py +++ b/tests/test_arrow_dataset.py @@ -3081,6 +3081,12 @@ def test_dataset_add_item_introduce_feature_type(): assert dataset[:] == {"col_1": [None, None, None, "a"]} +def test_dataset_filter_indices(): + ds = Dataset.from_dict({"num": [0, 1, 2, 3]}) + ds = ds.filter(lambda num: num % 2 == 0, input_columns="num", batch_size=2) + assert all(item["num"] % 2 == 0 for item in ds) + + @pytest.mark.parametrize("in_memory", [False, True]) def test_dataset_from_file(in_memory, dataset, arrow_file): filename = arrow_file From 17feaeee0e754df10d1deb7679b9902cb030019d Mon Sep 17 00:00:00 2001 From: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com> Date: Fri, 14 Oct 2022 13:28:26 +0200 Subject: [PATCH 2/3] Fix filter indices when batched --- src/datasets/arrow_dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/datasets/arrow_dataset.py b/src/datasets/arrow_dataset.py index a0b8481825d..c7ab133e88d 100644 --- a/src/datasets/arrow_dataset.py +++ b/src/datasets/arrow_dataset.py @@ -2961,7 +2961,7 @@ def init_buffer_and_writer(): else: writer.write(example) else: - for i, batch in enumerate(pbar): + for i, batch in zip(range(0, num_rows, batch_size), pbar): indices = list( range(*(slice(i, i + batch_size).indices(input_dataset.num_rows))) ) # Something simpler? From 391dd869b88ec9b0588f3bd85a0ecdfaa3255847 Mon Sep 17 00:00:00 2001 From: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com> Date: Fri, 14 Oct 2022 13:48:31 +0200 Subject: [PATCH 3/3] Rename test --- tests/test_arrow_dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_arrow_dataset.py b/tests/test_arrow_dataset.py index d37ab76c360..2b585dd2f89 100644 --- a/tests/test_arrow_dataset.py +++ b/tests/test_arrow_dataset.py @@ -3081,7 +3081,7 @@ def test_dataset_add_item_introduce_feature_type(): assert dataset[:] == {"col_1": [None, None, None, "a"]} -def test_dataset_filter_indices(): +def test_dataset_filter_batched_indices(): ds = Dataset.from_dict({"num": [0, 1, 2, 3]}) ds = ds.filter(lambda num: num % 2 == 0, input_columns="num", batch_size=2) assert all(item["num"] % 2 == 0 for item in ds)