Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
5 changed files
with
188 additions
and
24 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,109 @@ | ||
import os | ||
|
||
import pytest | ||
import yaml | ||
|
||
from datasets.features.features import Features, Value | ||
from datasets.info import DatasetInfo, DatasetInfosDict | ||
|
||
|
||
@pytest.mark.parametrize(
    "dataset_info",
    [
        DatasetInfo(),
        DatasetInfo(
            description="foo",
            features=Features({"a": Value("int32")}),
            builder_name="builder",
            config_name="config",
            version="1.0.0",
            splits=[{"name": "train"}],
            download_size=42,
        ),
    ],
)
def test_dataset_info_dump_and_reload(tmp_path, dataset_info: DatasetInfo):
    """Round-trip a DatasetInfo through write_to_directory/from_directory.

    Checks both an empty info object and a fully populated one, and that the
    dump produces a dataset_info.json file on disk.
    """
    target_dir = str(tmp_path)
    dataset_info.write_to_directory(target_dir)
    restored = DatasetInfo.from_directory(target_dir)
    # equality after the round trip means no field was lost or altered
    assert dataset_info == restored
    assert os.path.exists(os.path.join(target_dir, "dataset_info.json"))
|
||
|
||
def test_dataset_info_to_yaml_dict():
    """Check that _to_yaml_dict keeps exactly the whitelisted fields and stays YAML-safe."""
    dataset_info = DatasetInfo(
        description="foo",
        citation="bar",
        homepage="https://foo.bar",
        license="CC0",
        features=Features({"a": Value("int32")}),
        post_processed={},
        supervised_keys=tuple(),
        task_templates=[],
        builder_name="builder",
        config_name="config",
        version="1.0.0",
        splits=[{"name": "train", "num_examples": 42}],
        download_checksums={},
        download_size=1337,
        post_processing_size=442,
        dataset_size=1234,
        size_in_bytes=1337 + 442 + 1234,
    )
    yaml_dict = dataset_info._to_yaml_dict()
    # the keys must be precisely the whitelist, no extras and none missing
    assert sorted(yaml_dict) == sorted(DatasetInfo._INCLUDED_INFO_IN_YAML)
    for field_name in DatasetInfo._INCLUDED_INFO_IN_YAML:
        assert field_name in yaml_dict
        # every value must be a plain YAML-representable type
        assert isinstance(yaml_dict[field_name], (list, dict, int, str))
    # dumping and re-loading through safe YAML must be lossless
    serialized = yaml.safe_dump(yaml_dict)
    assert yaml_dict == yaml.safe_load(serialized)
|
||
|
||
def test_dataset_info_to_yaml_dict_empty():
    """An all-default DatasetInfo must serialize to an empty YAML mapping."""
    empty_info = DatasetInfo()
    assert empty_info._to_yaml_dict() == {}
|
||
|
||
@pytest.mark.parametrize(
    "dataset_infos_dict",
    [
        DatasetInfosDict(),
        DatasetInfosDict({"default": DatasetInfo()}),
        DatasetInfosDict(
            {
                "default": DatasetInfo(
                    description="foo",
                    features=Features({"a": Value("int32")}),
                    builder_name="builder",
                    config_name="config",
                    version="1.0.0",
                    splits=[{"name": "train"}],
                    download_size=42,
                )
            }
        ),
        DatasetInfosDict(
            {
                "v1": DatasetInfo(dataset_size=42),
                "v2": DatasetInfo(dataset_size=1337),
            }
        ),
    ],
)
def test_dataset_infos_dict_dump_and_reload(tmp_path, dataset_infos_dict: DatasetInfosDict):
    """Round-trip a DatasetInfosDict through write_to_directory/from_directory."""
    out_dir = str(tmp_path)
    dataset_infos_dict.write_to_directory(out_dir)
    restored = DatasetInfosDict.from_directory(out_dir)

    # the config_name key of dataset_infos_dict takes precedence over the attribute
    for name, info in dataset_infos_dict.items():
        info.config_name = name
        # the YAML representation doesn't include fields like description or citation,
        # so only compare what survives a YAML round trip
        dataset_infos_dict[name] = DatasetInfo._from_yaml_dict(info._to_yaml_dict())
    assert dataset_infos_dict == restored

    # a non-empty infos dict must have been written out as a README.md
    if dataset_infos_dict:
        assert os.path.exists(os.path.join(out_dir, "README.md"))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
import pytest | ||
|
||
from datasets.splits import SplitDict, SplitInfo | ||
from datasets.utils.py_utils import asdict | ||
|
||
|
||
@pytest.mark.parametrize(
    "split_dict",
    [
        SplitDict(),
        SplitDict({"train": SplitInfo(name="train", num_bytes=1337, num_examples=42, dataset_name="my_dataset")}),
        SplitDict({"train": SplitInfo(name="train", num_bytes=1337, num_examples=42)}),
        SplitDict({"train": SplitInfo()}),
    ],
)
def test_split_dict_to_yaml_list(split_dict: SplitDict):
    """Round-trip a SplitDict through its YAML-list representation."""
    yaml_list = split_dict._to_yaml_list()
    # one YAML entry per split
    assert len(yaml_list) == len(split_dict)
    restored = SplitDict._from_yaml_list(yaml_list)
    for name, info in split_dict.items():
        # dataset_name field is deprecated, and is therefore not part of the YAML dump
        info.dataset_name = None
        # the key of split_dict takes precedence over the SplitInfo's own name
        info.name = name
    assert split_dict == restored
|
||
|
||
@pytest.mark.parametrize(
    "split_info", [SplitInfo(), SplitInfo(dataset_name=None), SplitInfo(dataset_name="my_dataset")]
)
def test_split_dict_asdict_has_dataset_name(split_info):
    """asdict(SplitDict) must keep the deprecated "dataset_name" field.

    For backward compatibility, split info dictionaries need the "dataset_name"
    key even though it is deprecated, so that old versions of `datasets` can
    still reload dataset_infos.json files.
    """
    as_dict = asdict(SplitDict({"train": split_info}))
    assert "dataset_name" in as_dict["train"]
    assert as_dict["train"]["dataset_name"] == split_info.dataset_name
3d066ad
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Show benchmarks
PyArrow==6.0.0
Show updated benchmarks!
Benchmark: benchmark_array_xd.json
Benchmark: benchmark_getitem_100B.json
Benchmark: benchmark_indices_mapping.json
Benchmark: benchmark_iterating.json
Benchmark: benchmark_map_filter.json
Show updated benchmarks!
Benchmark: benchmark_array_xd.json
Benchmark: benchmark_getitem_100B.json
Benchmark: benchmark_indices_mapping.json
Benchmark: benchmark_iterating.json
Benchmark: benchmark_map_filter.json