From 5611581fdcdd3c9bbe15b673d21f9b54ea01c982 Mon Sep 17 00:00:00 2001 From: Sofia Oliveira <74454835+asofiaoliveira@users.noreply.github.com> Date: Wed, 5 Oct 2022 14:39:46 +0100 Subject: [PATCH 1/3] Set saved format in load_from_disk dataset --- src/datasets/arrow_dataset.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/datasets/arrow_dataset.py b/src/datasets/arrow_dataset.py index eda0b318128..ed6d3cf8393 100644 --- a/src/datasets/arrow_dataset.py +++ b/src/datasets/arrow_dataset.py @@ -1346,12 +1346,17 @@ def load_from_disk(dataset_path: str, fs=None, keep_in_memory: Optional[bool] = split = state["_split"] split = Split(split) if split is not None else split - return Dataset( + dataset = Dataset( arrow_table=arrow_table, info=dataset_info, split=split, fingerprint=state["_fingerprint"], ) + dataset.set_format( + state["_format_type"], state["_format_columns"], state["_output_all_columns"], **state["_format_kwargs"] + ) + + return dataset @property def data(self) -> Table: From 46ddb30add495cfc79b3fc8cf3f2b17c23ec3a10 Mon Sep 17 00:00:00 2001 From: Sofia Oliveira <74454835+asofiaoliveira@users.noreply.github.com> Date: Sat, 8 Oct 2022 21:09:36 +0100 Subject: [PATCH 2/3] Add test --- tests/test_arrow_dataset.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/tests/test_arrow_dataset.py b/tests/test_arrow_dataset.py index 4fe9ea1ea2b..86b8a7b16d2 100644 --- a/tests/test_arrow_dataset.py +++ b/tests/test_arrow_dataset.py @@ -318,6 +318,17 @@ def test_dummy_dataset_load_from_disk(self, in_memory): self.assertEqual(dset[0]["filename"], "my_name-train_0") self.assertEqual(dset["filename"][0], "my_name-train_0") + def test_restore_saved_format(self, in_memory): + with tempfile.TemporaryDirectory() as tmp_dir: + + with self._create_dummy_dataset(in_memory, tmp_dir, multiple_columns=True) as dset: + dset.set_format(type="numpy", columns=["col_1"], output_all_columns=True) + dataset_path = os.path.join(tmp_dir, "my_dataset") + dset.save_to_disk(dataset_path) + + with load_from_disk(dataset_path) as loaded_dset: + self.assertEqual(dset.format, loaded_dset.format) + def test_set_format_numpy_multiple_columns(self, in_memory): with tempfile.TemporaryDirectory() as tmp_dir: with self._create_dummy_dataset(in_memory, tmp_dir, multiple_columns=True) as dset: From 43c1836871c935e483dca1cb12aa2f5917a2ef02 Mon Sep 17 00:00:00 2001 From: mariosasko Date: Mon, 10 Oct 2022 19:55:05 +0200 Subject: [PATCH 3/3] Use with_format --- src/datasets/arrow_dataset.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/src/datasets/arrow_dataset.py b/src/datasets/arrow_dataset.py index ed6d3cf8393..df5ee1aba1a 100644 --- a/src/datasets/arrow_dataset.py +++ b/src/datasets/arrow_dataset.py @@ -1352,9 +1352,14 @@ def load_from_disk(dataset_path: str, fs=None, keep_in_memory: Optional[bool] = split=split, fingerprint=state["_fingerprint"], ) - dataset.set_format( - state["_format_type"], state["_format_columns"], state["_output_all_columns"], **state["_format_kwargs"] - ) + + format = { + "type": state["_format_type"], + "format_kwargs": state["_format_kwargs"], + "columns": state["_format_columns"], + "output_all_columns": state["_output_all_columns"], + } + dataset = dataset.with_format(**format) return dataset