add tests
lhoestq committed Sep 7, 2022
1 parent c52f40f commit 3d066ad
Showing 5 changed files with 188 additions and 24 deletions.
40 changes: 21 additions & 19 deletions src/datasets/info.py
@@ -313,23 +313,24 @@ def copy(self) -> "DatasetInfo":
 
     def _to_yaml_dict(self) -> dict:
         yaml_dict = {}
-        for field in dataclasses.fields(self):
-            if field.name in self._INCLUDED_INFO_IN_YAML:
-                value = getattr(self, field.name)
+        dataset_info_dict = asdict(self)
+        for key in dataset_info_dict:
+            if key in self._INCLUDED_INFO_IN_YAML:
+                value = getattr(self, key)
                 if hasattr(value, "_to_yaml_list"):  # Features, SplitDict
-                    yaml_dict[field.name] = value._to_yaml_list()
+                    yaml_dict[key] = value._to_yaml_list()
                 elif hasattr(value, "_to_yaml_string"):  # Version
-                    yaml_dict[field.name] = value._to_yaml_string()
+                    yaml_dict[key] = value._to_yaml_string()
                 else:
-                    yaml_dict[field.name] = value
+                    yaml_dict[key] = value
         return yaml_dict
 
     @classmethod
     def _from_yaml_dict(cls, yaml_data: dict) -> "DatasetInfo":
         yaml_data = copy.deepcopy(yaml_data)
-        if "features" in yaml_data:
+        if yaml_data.get("features") is not None:
             yaml_data["features"] = Features._from_yaml_list(yaml_data["features"])
-        if "splits" in yaml_data:
+        if yaml_data.get("splits") is not None:
             yaml_data["splits"] = SplitDict._from_yaml_list(yaml_data["splits"])
         field_names = {f.name for f in dataclasses.fields(cls)}
         return cls(**{k: v for k, v in yaml_data.items() if k in field_names})
@@ -346,11 +347,10 @@ def write_to_directory(self, dataset_infos_dir, overwrite=False, pretty_print=Fa
         if os.path.exists(dataset_infos_path):
             # for backward compatibility, let's update the JSON file if it exists
             with open(dataset_infos_path, "w", encoding="utf-8") as f:
-                json.dump(
-                    {config_name: asdict(dset_info) for config_name, dset_info in total_dataset_infos.items()},
-                    f,
-                    indent=4 if pretty_print else None,
-                )
+                dataset_infos_dict = {
+                    config_name: asdict(dset_info) for config_name, dset_info in total_dataset_infos.items()
+                }
+                json.dump(dataset_infos_dict, f, indent=4 if pretty_print else None)
         # Dump the infos in the YAML part of the README.md file
         if os.path.exists(dataset_readme_path):
             dataset_metadata = DatasetMetadata.from_readme(Path(dataset_readme_path))
@@ -365,6 +365,9 @@ def write_to_directory(self, dataset_infos_dir, overwrite=False, pretty_print=Fa
                 dataset_metadata["dataset_infos"] = dataset_metadata["dataset_infos"][0]
                 # no need to include the configuration name when there's only one configuration
                 dataset_metadata["dataset_infos"].pop("config_name", None)
+            else:
+                for config_name, dataset_info_yaml_dict in zip(total_dataset_infos, dataset_metadata["dataset_infos"]):
+                    dataset_info_yaml_dict["config_name"] = config_name
             dataset_metadata.to_readme(Path(dataset_readme_path))
 
     @classmethod
@@ -383,7 +386,7 @@ def from_directory(cls, dataset_infos_dir):
         # Load the info from the YAML part of README.md
         if os.path.exists(os.path.join(dataset_infos_dir, "README.md")):
             dataset_metadata = DatasetMetadata.from_readme(Path(dataset_infos_dir) / "README.md")
-            if isinstance(dataset_metadata.get("dataset_infos"), (list, dict)) and dataset_metadata["dataset_infos"]:
+            if isinstance(dataset_metadata.get("dataset_infos"), (list, dict)):
                 if isinstance(dataset_metadata["dataset_infos"], list):
                     dataset_infos_dict = {
                         dataset_info_yaml_dict.get("config_name", "default"): DatasetInfo._from_yaml_dict(
@@ -392,11 +395,10 @@ def from_directory(cls, dataset_infos_dir):
                         for dataset_info_yaml_dict in dataset_metadata["dataset_infos"]
                     }
                 else:
-                    dataset_infos_dict = {
-                        dataset_metadata["dataset_infos"].get("config_name", "default"): DatasetInfo._from_yaml_dict(
-                            dataset_metadata["dataset_infos"]
-                        )
-                    }
+                    dataset_info = DatasetInfo._from_yaml_dict(dataset_metadata["dataset_infos"])
+                    dataset_info.config_name = dataset_metadata["dataset_infos"].get("config_name", "default")
+                    dataset_infos_dict = {dataset_info.config_name: dataset_info}
 
         return cls(**dataset_infos_dict)

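For context, a minimal sketch of the round trip that the reworked _to_yaml_dict / _from_yaml_dict enable (illustrative values, mirroring the new tests in tests/test_info.py below; only the fields listed in DatasetInfo._INCLUDED_INFO_IN_YAML survive the trip):

import yaml

from datasets.features.features import Features, Value
from datasets.info import DatasetInfo

info = DatasetInfo(
    description="foo",  # not part of the YAML representation, so it is not round-tripped
    features=Features({"a": Value("int32")}),
    version="1.0.0",
    splits=[{"name": "train", "num_examples": 42}],
)
yaml_dict = info._to_yaml_dict()  # only the keys in DatasetInfo._INCLUDED_INFO_IN_YAML
reloaded = DatasetInfo._from_yaml_dict(yaml.safe_load(yaml.safe_dump(yaml_dict)))
assert reloaded.features == info.features  # features, splits, version, ... survive; description does not
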
25 changes: 21 additions & 4 deletions src/datasets/splits.py
@@ -18,9 +18,10 @@
 
 import abc
 import collections
+import copy
 import dataclasses
 import re
-from dataclasses import InitVar, dataclass
+from dataclasses import dataclass
 from typing import Dict, List, Optional, Union
 
 from .arrow_reader import FileInstructions, make_file_instructions
@@ -33,7 +34,14 @@ class SplitInfo:
     name: str = ""
     num_bytes: int = 0
     num_examples: int = 0
-    dataset_name: InitVar[Optional[str]] = None  # Pseudo-field: ignored by asdict/fields when converting to/from dict
+
+    # Deprecated
+    # For backward compatibility, this field needs to always be included in files like
+    # dataset_infos.json and dataset_info.json files
+    # To do so, we always include it in the output of datasets.utils.py_utils.asdict(split_info)
+    dataset_name: Optional[str] = dataclasses.field(
+        default=None, metadata={"include_in_asdict_even_if_is_default": True}
+    )
 
     @property
     def file_instructions(self):
@@ -560,13 +568,22 @@ def from_split_dict(cls, split_infos: Union[List, Dict], dataset_name: Optional[
     def to_split_dict(self):
         """Returns a list of SplitInfo protos that we have."""
         # Return the SplitInfo, sorted by name
-        return sorted((s for s in self.values()), key=lambda s: s.name)
+        out = []
+        for split_name, split_info in sorted(self.items()):
+            split_info = copy.deepcopy(split_info)
+            split_info.name = split_name
+            out.append(split_info)
+        return out
 
     def copy(self):
         return SplitDict.from_split_dict(self.to_split_dict(), self.dataset_name)
 
     def _to_yaml_list(self) -> list:
-        return [asdict(s) for s in self.to_split_dict()]
+        out = [asdict(s) for s in self.to_split_dict()]
+        # we don't need the dataset_name attribute that is deprecated
+        for split_info_dict in out:
+            split_info_dict.pop("dataset_name", None)
+        return out
 
     @classmethod
     def _from_yaml_list(cls, yaml_data: list) -> "SplitDict":
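A rough sketch of what the two splits.py changes add up to (illustrative values, mirroring the new tests in tests/test_splits.py below): asdict() keeps the deprecated dataset_name field so that older versions of `datasets` can still read dataset_infos.json, while the YAML representation drops it.

from datasets.splits import SplitDict, SplitInfo
from datasets.utils.py_utils import asdict

splits = SplitDict({"train": SplitInfo(name="train", num_bytes=1337, num_examples=42)})

# kept by asdict() thanks to the include_in_asdict_even_if_is_default metadata flag
assert asdict(splits)["train"]["dataset_name"] is None

# dropped from the YAML representation by _to_yaml_list()
assert all("dataset_name" not in split for split in splits._to_yaml_list())
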
2 changes: 1 addition & 1 deletion src/datasets/utils/py_utils.py
@@ -167,7 +167,7 @@ def _asdict_inner(obj):
             result = {}
             for f in fields(obj):
                 value = _asdict_inner(getattr(obj, f.name))
-                if value != f.default or not f.init:
+                if not f.init or value != f.default or f.metadata.get("include_in_asdict_even_if_is_default", False):
                     result[f.name] = value
             return result
         elif isinstance(obj, tuple) and hasattr(obj, "_fields"):
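The one-line change above is what the new SplitInfo.dataset_name metadata relies on. A self-contained sketch of the filtering rule, using a hypothetical Example dataclass and a simplified, non-recursive to_dict helper (not the library's asdict):

from dataclasses import dataclass, field, fields
from typing import Optional


@dataclass
class Example:
    name: str = ""
    # kept in the output even when it still equals its default, because of the metadata flag
    dataset_name: Optional[str] = field(default=None, metadata={"include_in_asdict_even_if_is_default": True})


def to_dict(obj):
    # same condition as the patched _asdict_inner, without the recursion
    return {
        f.name: getattr(obj, f.name)
        for f in fields(obj)
        if not f.init
        or getattr(obj, f.name) != f.default
        or f.metadata.get("include_in_asdict_even_if_is_default", False)
    }


assert to_dict(Example()) == {"dataset_name": None}
assert to_dict(Example(name="train")) == {"name": "train", "dataset_name": None}
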
109 changes: 109 additions & 0 deletions tests/test_info.py
@@ -0,0 +1,109 @@
import os

import pytest
import yaml

from datasets.features.features import Features, Value
from datasets.info import DatasetInfo, DatasetInfosDict


@pytest.mark.parametrize(
    "dataset_info",
    [
        DatasetInfo(),
        DatasetInfo(
            description="foo",
            features=Features({"a": Value("int32")}),
            builder_name="builder",
            config_name="config",
            version="1.0.0",
            splits=[{"name": "train"}],
            download_size=42,
        ),
    ],
)
def test_dataset_info_dump_and_reload(tmp_path, dataset_info: DatasetInfo):
    tmp_path = str(tmp_path)
    dataset_info.write_to_directory(tmp_path)
    reloaded = DatasetInfo.from_directory(tmp_path)
    assert dataset_info == reloaded
    assert os.path.exists(os.path.join(tmp_path, "dataset_info.json"))


def test_dataset_info_to_yaml_dict():
    dataset_info = DatasetInfo(
        description="foo",
        citation="bar",
        homepage="https://foo.bar",
        license="CC0",
        features=Features({"a": Value("int32")}),
        post_processed={},
        supervised_keys=tuple(),
        task_templates=[],
        builder_name="builder",
        config_name="config",
        version="1.0.0",
        splits=[{"name": "train", "num_examples": 42}],
        download_checksums={},
        download_size=1337,
        post_processing_size=442,
        dataset_size=1234,
        size_in_bytes=1337 + 442 + 1234,
    )
    dataset_info_yaml_dict = dataset_info._to_yaml_dict()
    assert sorted(dataset_info_yaml_dict) == sorted(DatasetInfo._INCLUDED_INFO_IN_YAML)
    for key in DatasetInfo._INCLUDED_INFO_IN_YAML:
        assert key in dataset_info_yaml_dict
        assert isinstance(dataset_info_yaml_dict[key], (list, dict, int, str))
    dataset_info_yaml = yaml.safe_dump(dataset_info_yaml_dict)
    reloaded = yaml.safe_load(dataset_info_yaml)
    assert dataset_info_yaml_dict == reloaded


def test_dataset_info_to_yaml_dict_empty():
    dataset_info = DatasetInfo()
    dataset_info_yaml_dict = dataset_info._to_yaml_dict()
    assert dataset_info_yaml_dict == {}


@pytest.mark.parametrize(
"dataset_infos_dict",
[
DatasetInfosDict(),
DatasetInfosDict({"default": DatasetInfo()}),
DatasetInfosDict(
{
"default": DatasetInfo(
description="foo",
features=Features({"a": Value("int32")}),
builder_name="builder",
config_name="config",
version="1.0.0",
splits=[{"name": "train"}],
download_size=42,
)
}
),
DatasetInfosDict(
{
"v1": DatasetInfo(dataset_size=42),
"v2": DatasetInfo(dataset_size=1337),
}
),
],
)
def test_dataset_infos_dict_dump_and_reload(tmp_path, dataset_infos_dict: DatasetInfosDict):
    tmp_path = str(tmp_path)
    dataset_infos_dict.write_to_directory(tmp_path)
    reloaded = DatasetInfosDict.from_directory(tmp_path)

    # the config_name keys of the dataset_infos_dict take precedence over the DatasetInfo.config_name attribute
    for config_name, dataset_info in dataset_infos_dict.items():
        dataset_info.config_name = config_name
        # the yaml representation doesn't include fields like description or citation
        # so we just test that we can recover what we can from the yaml
        dataset_infos_dict[config_name] = DatasetInfo._from_yaml_dict(dataset_info._to_yaml_dict())
    assert dataset_infos_dict == reloaded

    if dataset_infos_dict:
        assert os.path.exists(os.path.join(tmp_path, "README.md"))
36 changes: 36 additions & 0 deletions tests/test_splits.py
@@ -0,0 +1,36 @@
import pytest

from datasets.splits import SplitDict, SplitInfo
from datasets.utils.py_utils import asdict


@pytest.mark.parametrize(
"split_dict",
[
SplitDict(),
SplitDict({"train": SplitInfo(name="train", num_bytes=1337, num_examples=42, dataset_name="my_dataset")}),
SplitDict({"train": SplitInfo(name="train", num_bytes=1337, num_examples=42)}),
SplitDict({"train": SplitInfo()}),
],
)
def test_split_dict_to_yaml_list(split_dict: SplitDict):
split_dict_yaml_list = split_dict._to_yaml_list()
assert len(split_dict_yaml_list) == len(split_dict)
reloaded = SplitDict._from_yaml_list(split_dict_yaml_list)
for split_name, split_info in split_dict.items():
# dataset_name field is deprecated, and is therefore not part of the YAML dump
split_info.dataset_name = None
# the split name of split_dict takes over the name of the split info object
split_info.name = split_name
assert split_dict == reloaded


@pytest.mark.parametrize(
"split_info", [SplitInfo(), SplitInfo(dataset_name=None), SplitInfo(dataset_name="my_dataset")]
)
def test_split_dict_asdict_has_dataset_name(split_info):
# For backward compatibility, we need asdict(split_dict) to return split info dictrionaries with the "dataset_name"
# field even if it's deprecated. This way old versionso of `datasets` can still reload dataset_infos.json files
split_dict_asdict = asdict(SplitDict({"train": split_info}))
assert "dataset_name" in split_dict_asdict["train"]
assert split_dict_asdict["train"]["dataset_name"] == split_info.dataset_name

1 comment on commit 3d066ad

@github-actions

PyArrow==6.0.0


Benchmark: benchmark_array_xd.json

metric read_batch_formatted_as_numpy after write_array2d read_batch_formatted_as_numpy after write_flattened_sequence read_batch_formatted_as_numpy after write_nested_sequence read_batch_unformated after write_array2d read_batch_unformated after write_flattened_sequence read_batch_unformated after write_nested_sequence read_col_formatted_as_numpy after write_array2d read_col_formatted_as_numpy after write_flattened_sequence read_col_formatted_as_numpy after write_nested_sequence read_col_unformated after write_array2d read_col_unformated after write_flattened_sequence read_col_unformated after write_nested_sequence read_formatted_as_numpy after write_array2d read_formatted_as_numpy after write_flattened_sequence read_formatted_as_numpy after write_nested_sequence read_unformated after write_array2d read_unformated after write_flattened_sequence read_unformated after write_nested_sequence write_array2d write_flattened_sequence write_nested_sequence
new / old (diff) 0.008399 / 0.011353 (-0.002954) 0.004144 / 0.011008 (-0.006864) 0.029906 / 0.038508 (-0.008602) 0.036104 / 0.023109 (0.012995) 0.317218 / 0.275898 (0.041320) 0.373114 / 0.323480 (0.049635) 0.006376 / 0.007986 (-0.001610) 0.006028 / 0.004328 (0.001700) 0.007088 / 0.004250 (0.002837) 0.050493 / 0.037052 (0.013441) 0.313466 / 0.258489 (0.054977) 0.355271 / 0.293841 (0.061430) 0.031392 / 0.128546 (-0.097154) 0.009697 / 0.075646 (-0.065950) 0.258427 / 0.419271 (-0.160844) 0.053791 / 0.043533 (0.010258) 0.297224 / 0.255139 (0.042085) 0.326666 / 0.283200 (0.043466) 0.117516 / 0.141683 (-0.024167) 1.440999 / 1.452155 (-0.011155) 1.486845 / 1.492716 (-0.005871)

Benchmark: benchmark_getitem_100B.json

metric get_batch_of_1024_random_rows get_batch_of_1024_rows get_first_row get_last_row
new / old (diff) 0.245741 / 0.018006 (0.227735) 0.567180 / 0.000490 (0.566690) 0.002257 / 0.000200 (0.002058) 0.000143 / 0.000054 (0.000089)

Benchmark: benchmark_indices_mapping.json

metric select shard shuffle sort train_test_split
new / old (diff) 0.023712 / 0.037411 (-0.013699) 0.102219 / 0.014526 (0.087693) 0.113604 / 0.176557 (-0.062953) 0.157399 / 0.737135 (-0.579736) 0.118546 / 0.296338 (-0.177792)

Benchmark: benchmark_iterating.json

metric read 5000 read 50000 read_batch 50000 10 read_batch 50000 100 read_batch 50000 1000 read_formatted numpy 5000 read_formatted pandas 5000 read_formatted tensorflow 5000 read_formatted torch 5000 read_formatted_batch numpy 5000 10 read_formatted_batch numpy 5000 1000 shuffled read 5000 shuffled read 50000 shuffled read_batch 50000 10 shuffled read_batch 50000 100 shuffled read_batch 50000 1000 shuffled read_formatted numpy 5000 shuffled read_formatted_batch numpy 5000 10 shuffled read_formatted_batch numpy 5000 1000
new / old (diff) 0.394132 / 0.215209 (0.178923) 3.929147 / 2.077655 (1.851493) 1.786299 / 1.504120 (0.282180) 1.596985 / 1.541195 (0.055790) 1.667673 / 1.468490 (0.199183) 0.427410 / 4.584777 (-4.157367) 3.762696 / 3.745712 (0.016984) 3.578164 / 5.269862 (-1.691697) 1.726040 / 4.565676 (-2.839636) 0.052146 / 0.424275 (-0.372129) 0.011302 / 0.007607 (0.003695) 0.500424 / 0.226044 (0.274379) 5.039670 / 2.268929 (2.770741) 2.207434 / 55.444624 (-53.237191) 1.878865 / 6.876477 (-4.997612) 2.067892 / 2.142072 (-0.074180) 0.551577 / 4.805227 (-4.253650) 0.120636 / 6.500664 (-6.380028) 0.060865 / 0.075469 (-0.014604)

Benchmark: benchmark_map_filter.json

metric filter map fast-tokenizer batched map identity map identity batched map no-op batched map no-op batched numpy map no-op batched pandas map no-op batched pytorch map no-op batched tensorflow
new / old (diff) 1.500236 / 1.841788 (-0.341552) 14.017181 / 8.074308 (5.942873) 25.425922 / 10.191392 (15.234530) 0.868506 / 0.680424 (0.188082) 0.549439 / 0.534201 (0.015238) 0.385101 / 0.579283 (-0.194182) 0.421230 / 0.434364 (-0.013134) 0.270256 / 0.540337 (-0.270081) 0.268181 / 1.386936 (-1.118755)
PyArrow==latest

Benchmark: benchmark_array_xd.json

metric read_batch_formatted_as_numpy after write_array2d read_batch_formatted_as_numpy after write_flattened_sequence read_batch_formatted_as_numpy after write_nested_sequence read_batch_unformated after write_array2d read_batch_unformated after write_flattened_sequence read_batch_unformated after write_nested_sequence read_col_formatted_as_numpy after write_array2d read_col_formatted_as_numpy after write_flattened_sequence read_col_formatted_as_numpy after write_nested_sequence read_col_unformated after write_array2d read_col_unformated after write_flattened_sequence read_col_unformated after write_nested_sequence read_formatted_as_numpy after write_array2d read_formatted_as_numpy after write_flattened_sequence read_formatted_as_numpy after write_nested_sequence read_unformated after write_array2d read_unformated after write_flattened_sequence read_unformated after write_nested_sequence write_array2d write_flattened_sequence write_nested_sequence
new / old (diff) 0.006426 / 0.011353 (-0.004926) 0.004273 / 0.011008 (-0.006736) 0.027933 / 0.038508 (-0.010575) 0.034788 / 0.023109 (0.011679) 0.363691 / 0.275898 (0.087793) 0.448217 / 0.323480 (0.124737) 0.004273 / 0.007986 (-0.003713) 0.003743 / 0.004328 (-0.000585) 0.005040 / 0.004250 (0.000789) 0.043950 / 0.037052 (0.006897) 0.381859 / 0.258489 (0.123370) 0.415619 / 0.293841 (0.121778) 0.030373 / 0.128546 (-0.098173) 0.009985 / 0.075646 (-0.065662) 0.257246 / 0.419271 (-0.162026) 0.054185 / 0.043533 (0.010652) 0.361470 / 0.255139 (0.106331) 0.372161 / 0.283200 (0.088962) 0.109529 / 0.141683 (-0.032154) 1.493575 / 1.452155 (0.041420) 1.520831 / 1.492716 (0.028114)

Benchmark: benchmark_getitem_100B.json

metric get_batch_of_1024_random_rows get_batch_of_1024_rows get_first_row get_last_row
new / old (diff) 0.249035 / 0.018006 (0.231029) 0.497010 / 0.000490 (0.496521) 0.003808 / 0.000200 (0.003608) 0.000103 / 0.000054 (0.000049)

Benchmark: benchmark_indices_mapping.json

metric select shard shuffle sort train_test_split
new / old (diff) 0.024279 / 0.037411 (-0.013133) 0.101972 / 0.014526 (0.087447) 0.115548 / 0.176557 (-0.061009) 0.172648 / 0.737135 (-0.564488) 0.121644 / 0.296338 (-0.174695)

Benchmark: benchmark_iterating.json

metric read 5000 read 50000 read_batch 50000 10 read_batch 50000 100 read_batch 50000 1000 read_formatted numpy 5000 read_formatted pandas 5000 read_formatted tensorflow 5000 read_formatted torch 5000 read_formatted_batch numpy 5000 10 read_formatted_batch numpy 5000 1000 shuffled read 5000 shuffled read 50000 shuffled read_batch 50000 10 shuffled read_batch 50000 100 shuffled read_batch 50000 1000 shuffled read_formatted numpy 5000 shuffled read_formatted_batch numpy 5000 10 shuffled read_formatted_batch numpy 5000 1000
new / old (diff) 0.419392 / 0.215209 (0.204183) 4.186357 / 2.077655 (2.108703) 2.003440 / 1.504120 (0.499320) 1.819134 / 1.541195 (0.277940) 1.917069 / 1.468490 (0.448579) 0.433444 / 4.584777 (-4.151333) 3.744289 / 3.745712 (-0.001424) 2.025747 / 5.269862 (-3.244115) 1.232137 / 4.565676 (-3.333540) 0.052304 / 0.424275 (-0.371971) 0.011100 / 0.007607 (0.003493) 0.521202 / 0.226044 (0.295158) 5.221947 / 2.268929 (2.953018) 2.501250 / 55.444624 (-52.943374) 2.147942 / 6.876477 (-4.728535) 2.330282 / 2.142072 (0.188210) 0.542531 / 4.805227 (-4.262696) 0.123193 / 6.500664 (-6.377471) 0.062231 / 0.075469 (-0.013238)

Benchmark: benchmark_map_filter.json

metric filter map fast-tokenizer batched map identity map identity batched map no-op batched map no-op batched numpy map no-op batched pandas map no-op batched pytorch map no-op batched tensorflow
new / old (diff) 1.520781 / 1.841788 (-0.321006) 14.357662 / 8.074308 (6.283354) 25.118336 / 10.191392 (14.926944) 0.896546 / 0.680424 (0.216122) 0.592922 / 0.534201 (0.058721) 0.386306 / 0.579283 (-0.192978) 0.432824 / 0.434364 (-0.001540) 0.275917 / 0.540337 (-0.264420) 0.284491 / 1.386936 (-1.102445)
