From 3d066ad64f1ca947d5c587d548de6163b8864651 Mon Sep 17 00:00:00 2001
From: Quentin Lhoest
Date: Wed, 7 Sep 2022 15:12:43 +0200
Subject: [PATCH] add tests

---
 src/datasets/info.py           |  40 ++++++------
 src/datasets/splits.py         |  25 ++++++--
 src/datasets/utils/py_utils.py |   2 +-
 tests/test_info.py             | 109 +++++++++++++++++++++++++++++++++
 tests/test_splits.py           |  36 +++++++++++
 5 files changed, 188 insertions(+), 24 deletions(-)
 create mode 100644 tests/test_info.py
 create mode 100644 tests/test_splits.py

diff --git a/src/datasets/info.py b/src/datasets/info.py
index 87473bd8178..f3f5f7ae40a 100644
--- a/src/datasets/info.py
+++ b/src/datasets/info.py
@@ -313,23 +313,24 @@ def copy(self) -> "DatasetInfo":
 
     def _to_yaml_dict(self) -> dict:
         yaml_dict = {}
-        for field in dataclasses.fields(self):
-            if field.name in self._INCLUDED_INFO_IN_YAML:
-                value = getattr(self, field.name)
+        dataset_info_dict = asdict(self)
+        for key in dataset_info_dict:
+            if key in self._INCLUDED_INFO_IN_YAML:
+                value = getattr(self, key)
                 if hasattr(value, "_to_yaml_list"):  # Features, SplitDict
-                    yaml_dict[field.name] = value._to_yaml_list()
+                    yaml_dict[key] = value._to_yaml_list()
                 elif hasattr(value, "_to_yaml_string"):  # Version
-                    yaml_dict[field.name] = value._to_yaml_string()
+                    yaml_dict[key] = value._to_yaml_string()
                 else:
-                    yaml_dict[field.name] = value
+                    yaml_dict[key] = value
         return yaml_dict
 
     @classmethod
     def _from_yaml_dict(cls, yaml_data: dict) -> "DatasetInfo":
         yaml_data = copy.deepcopy(yaml_data)
-        if "features" in yaml_data:
+        if yaml_data.get("features") is not None:
             yaml_data["features"] = Features._from_yaml_list(yaml_data["features"])
-        if "splits" in yaml_data:
+        if yaml_data.get("splits") is not None:
             yaml_data["splits"] = SplitDict._from_yaml_list(yaml_data["splits"])
         field_names = {f.name for f in dataclasses.fields(cls)}
         return cls(**{k: v for k, v in yaml_data.items() if k in field_names})
@@ -346,11 +347,10 @@ def write_to_directory(self, dataset_infos_dir, overwrite=False, pretty_print=Fa
         if os.path.exists(dataset_infos_path):
             # for backward compatibility, let's update the JSON file if it exists
             with open(dataset_infos_path, "w", encoding="utf-8") as f:
-                json.dump(
-                    {config_name: asdict(dset_info) for config_name, dset_info in total_dataset_infos.items()},
-                    f,
-                    indent=4 if pretty_print else None,
-                )
+                dataset_infos_dict = {
+                    config_name: asdict(dset_info) for config_name, dset_info in total_dataset_infos.items()
+                }
+                json.dump(dataset_infos_dict, f, indent=4 if pretty_print else None)
         # Dump the infos in the YAML part of the README.md file
         if os.path.exists(dataset_readme_path):
             dataset_metadata = DatasetMetadata.from_readme(Path(dataset_readme_path))
@@ -365,6 +365,9 @@ def write_to_directory(self, dataset_infos_dir, overwrite=False, pretty_print=Fa
                 dataset_metadata["dataset_infos"] = dataset_metadata["dataset_infos"][0]
                 # no need to include the configuration name when there's only one configuration
                 dataset_metadata["dataset_infos"].pop("config_name", None)
+            else:
+                for config_name, dataset_info_yaml_dict in zip(total_dataset_infos, dataset_metadata["dataset_infos"]):
+                    dataset_info_yaml_dict["config_name"] = config_name
         dataset_metadata.to_readme(Path(dataset_readme_path))
 
     @classmethod
@@ -383,7 +386,7 @@ def from_directory(cls, dataset_infos_dir):
         # Load the info from the YAML part of README.md
         if os.path.exists(os.path.join(dataset_infos_dir, "README.md")):
             dataset_metadata = DatasetMetadata.from_readme(Path(dataset_infos_dir) / "README.md")
-            if isinstance(dataset_metadata.get("dataset_infos"), (list, dict)) and dataset_metadata["dataset_infos"]:
+            if isinstance(dataset_metadata.get("dataset_infos"), (list, dict)):
                 if isinstance(dataset_metadata["dataset_infos"], list):
                     dataset_infos_dict = {
                         dataset_info_yaml_dict.get("config_name", "default"): DatasetInfo._from_yaml_dict(
@@ -392,11 +395,10 @@ def from_directory(cls, dataset_infos_dir):
                         for dataset_info_yaml_dict in dataset_metadata["dataset_infos"]
                     }
                 else:
-                    dataset_infos_dict = {
-                        dataset_metadata["dataset_infos"].get("config_name", "default"): DatasetInfo._from_yaml_dict(
-                            dataset_metadata["dataset_infos"]
-                        )
-                    }
+                    dataset_info = DatasetInfo._from_yaml_dict(dataset_metadata["dataset_infos"])
+                    dataset_info.config_name = dataset_metadata["dataset_infos"].get("config_name", "default")
+                    dataset_infos_dict = {dataset_info.config_name: dataset_info}
+
         return cls(**dataset_infos_dict)
 
diff --git a/src/datasets/splits.py b/src/datasets/splits.py
index a9538ab7b22..11912f1c951 100644
--- a/src/datasets/splits.py
+++ b/src/datasets/splits.py
@@ -18,9 +18,10 @@
 
 import abc
 import collections
+import copy
 import dataclasses
 import re
-from dataclasses import InitVar, dataclass
+from dataclasses import dataclass
 from typing import Dict, List, Optional, Union
 
 from .arrow_reader import FileInstructions, make_file_instructions
@@ -33,7 +34,14 @@ class SplitInfo:
     name: str = ""
     num_bytes: int = 0
     num_examples: int = 0
-    dataset_name: InitVar[Optional[str]] = None  # Pseudo-field: ignored by asdict/fields when converting to/from dict
+
+    # Deprecated
+    # For backward compatibility, this field must always be included in files
+    # like dataset_infos.json and dataset_info.json
+    # To do so, we always include it in the output of datasets.utils.py_utils.asdict(split_info)
+    dataset_name: Optional[str] = dataclasses.field(
+        default=None, metadata={"include_in_asdict_even_if_is_default": True}
+    )
 
     @property
     def file_instructions(self):
@@ -560,13 +568,22 @@ def from_split_dict(cls, split_infos: Union[List, Dict], dataset_name: Optional[
     def to_split_dict(self):
         """Returns a list of SplitInfo protos that we have."""
         # Return the SplitInfo, sorted by name
-        return sorted((s for s in self.values()), key=lambda s: s.name)
+        out = []
+        for split_name, split_info in sorted(self.items()):
+            split_info = copy.deepcopy(split_info)
+            split_info.name = split_name
+            out.append(split_info)
+        return out
 
     def copy(self):
         return SplitDict.from_split_dict(self.to_split_dict(), self.dataset_name)
 
     def _to_yaml_list(self) -> list:
-        return [asdict(s) for s in self.to_split_dict()]
+        out = [asdict(s) for s in self.to_split_dict()]
+        # the dataset_name attribute is deprecated, so we don't dump it to YAML
+        for split_info_dict in out:
+            split_info_dict.pop("dataset_name", None)
+        return out
 
     @classmethod
     def _from_yaml_list(cls, yaml_data: list) -> "SplitDict":
diff --git a/src/datasets/utils/py_utils.py b/src/datasets/utils/py_utils.py
index 8485dc57daa..06f4934149d 100644
--- a/src/datasets/utils/py_utils.py
+++ b/src/datasets/utils/py_utils.py
@@ -167,7 +167,7 @@ def _asdict_inner(obj):
         result = {}
         for f in fields(obj):
             value = _asdict_inner(getattr(obj, f.name))
-            if value != f.default or not f.init:
+            if not f.init or value != f.default or f.metadata.get("include_in_asdict_even_if_is_default", False):
                 result[f.name] = value
         return result
     elif isinstance(obj, tuple) and hasattr(obj, "_fields"):
diff --git a/tests/test_info.py b/tests/test_info.py
new file mode 100644
index 00000000000..8100f72a96b
--- /dev/null
+++ b/tests/test_info.py
@@ -0,0 +1,109 @@
+import os
+
+import pytest
+import yaml
+
+from datasets.features.features import Features, Value
+from datasets.info import DatasetInfo, DatasetInfosDict
+
+
+@pytest.mark.parametrize(
+    "dataset_info",
+    [
+        DatasetInfo(),
+        DatasetInfo(
+            description="foo",
+            features=Features({"a": Value("int32")}),
+            builder_name="builder",
+            config_name="config",
+            version="1.0.0",
+            splits=[{"name": "train"}],
+            download_size=42,
+        ),
+    ],
+)
+def test_dataset_info_dump_and_reload(tmp_path, dataset_info: DatasetInfo):
+    tmp_path = str(tmp_path)
+    dataset_info.write_to_directory(tmp_path)
+    reloaded = DatasetInfo.from_directory(tmp_path)
+    assert dataset_info == reloaded
+    assert os.path.exists(os.path.join(tmp_path, "dataset_info.json"))
+
+
+def test_dataset_info_to_yaml_dict():
+    dataset_info = DatasetInfo(
+        description="foo",
+        citation="bar",
+        homepage="https://foo.bar",
+        license="CC0",
+        features=Features({"a": Value("int32")}),
+        post_processed={},
+        supervised_keys=tuple(),
+        task_templates=[],
+        builder_name="builder",
+        config_name="config",
+        version="1.0.0",
+        splits=[{"name": "train", "num_examples": 42}],
+        download_checksums={},
+        download_size=1337,
+        post_processing_size=442,
+        dataset_size=1234,
+        size_in_bytes=1337 + 442 + 1234,
+    )
+    dataset_info_yaml_dict = dataset_info._to_yaml_dict()
+    assert sorted(dataset_info_yaml_dict) == sorted(DatasetInfo._INCLUDED_INFO_IN_YAML)
+    for key in DatasetInfo._INCLUDED_INFO_IN_YAML:
+        assert key in dataset_info_yaml_dict
+        assert isinstance(dataset_info_yaml_dict[key], (list, dict, int, str))
+    dataset_info_yaml = yaml.safe_dump(dataset_info_yaml_dict)
+    reloaded = yaml.safe_load(dataset_info_yaml)
+    assert dataset_info_yaml_dict == reloaded
+
+
+def test_dataset_info_to_yaml_dict_empty():
+    dataset_info = DatasetInfo()
+    dataset_info_yaml_dict = dataset_info._to_yaml_dict()
+    assert dataset_info_yaml_dict == {}
+
+
+@pytest.mark.parametrize(
+    "dataset_infos_dict",
+    [
+        DatasetInfosDict(),
+        DatasetInfosDict({"default": DatasetInfo()}),
+        DatasetInfosDict(
+            {
+                "default": DatasetInfo(
+                    description="foo",
+                    features=Features({"a": Value("int32")}),
+                    builder_name="builder",
+                    config_name="config",
+                    version="1.0.0",
+                    splits=[{"name": "train"}],
+                    download_size=42,
+                )
+            }
+        ),
+        DatasetInfosDict(
+            {
+                "v1": DatasetInfo(dataset_size=42),
+                "v2": DatasetInfo(dataset_size=1337),
+            }
+        ),
+    ],
+)
+def test_dataset_infos_dict_dump_and_reload(tmp_path, dataset_infos_dict: DatasetInfosDict):
+    tmp_path = str(tmp_path)
+    dataset_infos_dict.write_to_directory(tmp_path)
+    reloaded = DatasetInfosDict.from_directory(tmp_path)
+
+    # the config_name of the dataset_infos_dict takes precedence over the config_name attribute
+    for config_name, dataset_info in dataset_infos_dict.items():
+        dataset_info.config_name = config_name
+        # the yaml representation doesn't include fields like description or citation
+        # so we only test that we can recover the fields that the yaml does include
+        dataset_infos_dict[config_name] = DatasetInfo._from_yaml_dict(dataset_info._to_yaml_dict())
+    assert dataset_infos_dict == reloaded
+
+    if dataset_infos_dict:
+        assert os.path.exists(os.path.join(tmp_path, "README.md"))
diff --git a/tests/test_splits.py b/tests/test_splits.py
new file mode 100644
index 00000000000..bce980e36ab
--- /dev/null
+++ b/tests/test_splits.py
@@ -0,0 +1,36 @@
+import pytest
+
+from datasets.splits import SplitDict, SplitInfo
+from datasets.utils.py_utils import asdict
+
+
+@pytest.mark.parametrize(
+    "split_dict",
+    [
+        SplitDict(),
+        SplitDict({"train": SplitInfo(name="train", num_bytes=1337, num_examples=42, dataset_name="my_dataset")}),
+        SplitDict({"train": SplitInfo(name="train", num_bytes=1337, num_examples=42)}),
+        SplitDict({"train": SplitInfo()}),
+    ],
+)
+def test_split_dict_to_yaml_list(split_dict: SplitDict):
+    split_dict_yaml_list = split_dict._to_yaml_list()
+    assert len(split_dict_yaml_list) == len(split_dict)
+    reloaded = SplitDict._from_yaml_list(split_dict_yaml_list)
+    for split_name, split_info in split_dict.items():
+        # the dataset_name field is deprecated, and is therefore not part of the YAML dump
+        split_info.dataset_name = None
+        # the split name of split_dict takes precedence over the name of the SplitInfo object
+        split_info.name = split_name
+    assert split_dict == reloaded
+
+
+@pytest.mark.parametrize(
+    "split_info", [SplitInfo(), SplitInfo(dataset_name=None), SplitInfo(dataset_name="my_dataset")]
+)
+def test_split_dict_asdict_has_dataset_name(split_info):
+    # For backward compatibility, we need asdict(split_dict) to return split info dictionaries with the "dataset_name"
+    # field even if it's deprecated. This way, old versions of `datasets` can still reload dataset_infos.json files
+    split_dict_asdict = asdict(SplitDict({"train": split_info}))
+    assert "dataset_name" in split_dict_asdict["train"]
+    assert split_dict_asdict["train"]["dataset_name"] == split_info.dataset_name
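---

The backward-compatibility mechanism in this patch hinges on the
"include_in_asdict_even_if_is_default" field metadata that the patched
_asdict_inner honors. Below is a minimal standalone sketch of the idea;
ExampleSplitInfo and to_dict are hypothetical names for illustration only,
not part of the library (the real, recursive logic lives in
datasets.utils.py_utils._asdict_inner):

    from dataclasses import dataclass, field, fields
    from typing import Optional


    @dataclass
    class ExampleSplitInfo:
        # Hypothetical stand-in for datasets.splits.SplitInfo.
        name: str = ""
        num_examples: int = 0
        # Deprecated field: the metadata flag opts it back into dict dumps
        # even when it still equals its default, so older readers keep
        # finding it.
        dataset_name: Optional[str] = field(
            default=None, metadata={"include_in_asdict_even_if_is_default": True}
        )


    def to_dict(obj) -> dict:
        # Simplified, non-recursive version of the patched filter in
        # _asdict_inner: skip fields still at their default value unless the
        # field opts back in via its metadata.
        return {
            f.name: getattr(obj, f.name)
            for f in fields(obj)
            if not f.init
            or getattr(obj, f.name) != f.default
            or f.metadata.get("include_in_asdict_even_if_is_default", False)
        }


    info = ExampleSplitInfo(name="train", num_examples=42)
    # name and num_examples differ from their defaults; dataset_name is kept
    # despite being at its default, because of the metadata flag.
    assert to_dict(info) == {"name": "train", "num_examples": 42, "dataset_name": None}

This is why old versions of `datasets` can still find the "dataset_name" key
in dataset_infos.json files even though the field is deprecated.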