diff --git a/src/datasets/arrow_dataset.py b/src/datasets/arrow_dataset.py
index eda0b318128..df5ee1aba1a 100644
--- a/src/datasets/arrow_dataset.py
+++ b/src/datasets/arrow_dataset.py
@@ -1346,13 +1346,23 @@ def load_from_disk(dataset_path: str, fs=None, keep_in_memory: Optional[bool] =
         split = state["_split"]
         split = Split(split) if split is not None else split
 
-        return Dataset(
+        dataset = Dataset(
             arrow_table=arrow_table,
             info=dataset_info,
             split=split,
             fingerprint=state["_fingerprint"],
         )
+        format = {
+            "type": state["_format_type"],
+            "format_kwargs": state["_format_kwargs"],
+            "columns": state["_format_columns"],
+            "output_all_columns": state["_output_all_columns"],
+        }
+        dataset = dataset.with_format(**format)
+
+        return dataset
+
 
     @property
     def data(self) -> Table:
         """The Apache Arrow table backing the dataset.
diff --git a/tests/test_arrow_dataset.py b/tests/test_arrow_dataset.py
index 4fe9ea1ea2b..86b8a7b16d2 100644
--- a/tests/test_arrow_dataset.py
+++ b/tests/test_arrow_dataset.py
@@ -318,6 +318,17 @@ def test_dummy_dataset_load_from_disk(self, in_memory):
                 self.assertEqual(dset[0]["filename"], "my_name-train_0")
                 self.assertEqual(dset["filename"][0], "my_name-train_0")
 
+    def test_restore_saved_format(self, in_memory):
+        with tempfile.TemporaryDirectory() as tmp_dir:
+
+            with self._create_dummy_dataset(in_memory, tmp_dir, multiple_columns=True) as dset:
+                dset.set_format(type="numpy", columns=["col_1"], output_all_columns=True)
+                dataset_path = os.path.join(tmp_dir, "my_dataset")
+                dset.save_to_disk(dataset_path)
+
+            with load_from_disk(dataset_path) as loaded_dset:
+                self.assertEqual(dset.format, loaded_dset.format)
+
     def test_set_format_numpy_multiple_columns(self, in_memory):