Skip to content

Commit

Permalink
Add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
mariosasko committed Sep 8, 2022
1 parent 5024d0a commit 0e241dd
Showing 1 changed file with 54 additions and 0 deletions.
54 changes: 54 additions & 0 deletions tests/test_arrow_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -3101,6 +3101,60 @@ def test_dataset_from_text_path_type(path_type, text_path, tmp_path):
_check_text_dataset(dataset, expected_features)


@pytest.fixture
def data_generator():
def _gen():
data = [
{"col_1": "0", "col_2": 0, "col_3": 0.0},
{"col_1": "1", "col_2": 1, "col_3": 1.0},
{"col_1": "2", "col_2": 2, "col_3": 2.0},
{"col_1": "3", "col_2": 3, "col_3": 3.0},
]
for item in data:
yield item

return _gen


def _check_generator_dataset(dataset, expected_features):
assert isinstance(dataset, Dataset)
assert dataset.num_rows == 4
assert dataset.num_columns == 3
assert dataset.column_names == ["col_1", "col_2", "col_3"]
for feature, expected_dtype in expected_features.items():
assert dataset.features[feature].dtype == expected_dtype


@pytest.mark.parametrize("keep_in_memory", [False, True])
def test_dataset_from_generator_keep_in_memory(keep_in_memory, data_generator, tmp_path):
cache_dir = tmp_path / "cache"
expected_features = {"col_1": "string", "col_2": "int64", "col_3": "float64"}
with assert_arrow_memory_increases() if keep_in_memory else assert_arrow_memory_doesnt_increase():
dataset = Dataset.from_generator(data_generator, cache_dir=cache_dir, keep_in_memory=keep_in_memory)
_check_generator_dataset(dataset, expected_features)


@pytest.mark.parametrize(
"features",
[
None,
{"col_1": "string", "col_2": "int64", "col_3": "float64"},
{"col_1": "string", "col_2": "string", "col_3": "string"},
{"col_1": "int32", "col_2": "int32", "col_3": "int32"},
{"col_1": "float32", "col_2": "float32", "col_3": "float32"},
],
)
def test_dataset_from_generator_features(features, data_generator, tmp_path):
cache_dir = tmp_path / "cache"
default_expected_features = {"col_1": "string", "col_2": "int64", "col_3": "float64"}
expected_features = features.copy() if features else default_expected_features
features = (
Features({feature: Value(dtype) for feature, dtype in features.items()}) if features is not None else None
)
dataset = Dataset.from_generator(data_generator, features=features, cache_dir=cache_dir)
_check_generator_dataset(dataset, expected_features)


def test_dataset_to_json(dataset, tmp_path):
file_path = tmp_path / "test_path.jsonl"
bytes_written = dataset.to_json(path_or_buf=file_path)
Expand Down

0 comments on commit 0e241dd

Please sign in to comment.