Skip to content

Commit

Permalink
Add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
mariosasko committed Aug 11, 2022
1 parent 01873fa commit 0e9b25f
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 18 deletions.
81 changes: 63 additions & 18 deletions tests/packaged_modules/test_imagefolder.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,8 +99,8 @@ def image_files_with_metadata_that_misses_one_image(tmp_path, image_file):
return str(image_filename), str(image_filename2), str(image_metadata_filename)


@pytest.fixture
def data_files_with_one_split_and_metadata(tmp_path, image_file):
@pytest.fixture(params=["jsonl", "csv"])
def data_files_with_one_split_and_metadata(request, tmp_path, image_file):
data_dir = tmp_path / "imagefolder_data_dir_with_metadata_one_split"
data_dir.mkdir(parents=True, exist_ok=True)
subdir = data_dir / "subdir"
Expand All @@ -113,13 +113,24 @@ def data_files_with_one_split_and_metadata(tmp_path, image_file):
image_filename3 = subdir / "image_rgb3.jpg" # in subdir
shutil.copyfile(image_file, image_filename3)

image_metadata_filename = data_dir / "metadata.jsonl"
image_metadata = textwrap.dedent(
"""\
image_metadata_filename = data_dir / f"metadata.{request.param}"
image_metadata = (
textwrap.dedent(
"""\
{"file_name": "image_rgb.jpg", "caption": "Nice image"}
{"file_name": "image_rgb2.jpg", "caption": "Nice second image"}
{"file_name": "subdir/image_rgb3.jpg", "caption": "Nice third image"}
"""
)
if request.param == "jsonl"
else textwrap.dedent(
"""\
file_name,caption
image_rgb.jpg,Nice image
image_rgb2.jpg,Nice second image
subdir/image_rgb3.jpg,Nice third image
"""
)
)
with open(image_metadata_filename, "w", encoding="utf-8") as f:
f.write(image_metadata)
Expand All @@ -131,8 +142,8 @@ def data_files_with_one_split_and_metadata(tmp_path, image_file):
return data_files_with_one_split_and_metadata


@pytest.fixture
def data_files_with_two_splits_and_metadata(tmp_path, image_file):
@pytest.fixture(params=["jsonl", "csv"])
def data_files_with_two_splits_and_metadata(request, tmp_path, image_file):
data_dir = tmp_path / "imagefolder_data_dir_with_metadata_two_splits"
data_dir.mkdir(parents=True, exist_ok=True)
train_dir = data_dir / "train"
Expand All @@ -147,20 +158,39 @@ def data_files_with_two_splits_and_metadata(tmp_path, image_file):
image_filename3 = test_dir / "image_rgb3.jpg" # test image
shutil.copyfile(image_file, image_filename3)

train_image_metadata_filename = train_dir / "metadata.jsonl"
image_metadata = textwrap.dedent(
"""\
train_image_metadata_filename = train_dir / f"metadata.{request.param}"
image_metadata = (
textwrap.dedent(
"""\
{"file_name": "image_rgb.jpg", "caption": "Nice train image"}
{"file_name": "image_rgb2.jpg", "caption": "Nice second train image"}
"""
)
if request.param == "jsonl"
else textwrap.dedent(
"""\
file_name,caption
image_rgb.jpg,Nice train image
image_rgb2.jpg,Nice second train image
"""
)
)
with open(train_image_metadata_filename, "w", encoding="utf-8") as f:
f.write(image_metadata)
test_image_metadata_filename = test_dir / "metadata.jsonl"
image_metadata = textwrap.dedent(
"""\
test_image_metadata_filename = test_dir / f"metadata.{request.param}"
image_metadata = (
textwrap.dedent(
"""\
{"file_name": "image_rgb3.jpg", "caption": "Nice test image"}
"""
)
if request.param == "jsonl"
else textwrap.dedent(
"""\
file_name,caption
image_rgb3.jpg,Nice test image
"""
)
)
with open(test_image_metadata_filename, "w", encoding="utf-8") as f:
f.write(image_metadata)
Expand Down Expand Up @@ -354,11 +384,26 @@ def test_generate_examples_with_metadata_that_misses_one_image(

@require_pil
@pytest.mark.parametrize("streaming", [False, True])
@pytest.mark.parametrize("n_splits", [1, 2])
def test_data_files_with_metadata_and_splits(
streaming, cache_dir, n_splits, data_files_with_one_split_and_metadata, data_files_with_two_splits_and_metadata
):
data_files = data_files_with_one_split_and_metadata if n_splits == 1 else data_files_with_two_splits_and_metadata
def test_data_files_with_metadata_and_single_split(streaming, cache_dir, data_files_with_one_split_and_metadata):
    """Build an ImageFolder from single-split data files and check image/caption pairing.

    For every split, the loaded dataset must contain one example per image,
    and each example must have a distinct image file and a distinct,
    non-null caption.
    """
    split_to_files = data_files_with_one_split_and_metadata
    builder = ImageFolder(data_files=split_to_files, cache_dir=cache_dir)
    builder.download_and_prepare()
    if streaming:
        datasets = builder.as_streaming_dataset()
    else:
        datasets = builder.as_dataset()
    for split, files in split_to_files.items():
        # the file list includes the metadata file, which is not an example
        n_images = len(files) - 1
        assert split in datasets
        examples = list(datasets[split])
        assert len(examples) == n_images
        # distinct filenames and captions show no sample was duplicated or mismatched
        filenames = {example["image"].filename for example in examples}
        assert len(filenames) == n_images
        captions = {example["caption"] for example in examples}
        assert len(captions) == n_images
        assert all(example["caption"] is not None for example in examples)


@require_pil
@pytest.mark.parametrize("streaming", [False, True])
def test_data_files_with_metadata_and_multiple_splits(streaming, cache_dir, data_files_with_two_splits_and_metadata):
data_files = data_files_with_two_splits_and_metadata
imagefolder = ImageFolder(data_files=data_files, cache_dir=cache_dir)
imagefolder.download_and_prepare()
datasets = imagefolder.as_streaming_dataset() if streaming else imagefolder.as_dataset()
Expand Down
3 changes: 3 additions & 0 deletions tests/test_data_files.py
Original file line number Diff line number Diff line change
Expand Up @@ -566,6 +566,7 @@ def ls(self, path, detail=True, refresh=True, **kwargs):
{"train": "dataset.txt"},
{"train": "data/dataset.txt"},
{"train": ["data/image.jpg", "metadata.jsonl"]},
{"train": ["data/image.jpg", "metadata.csv"]},
# With prefix or suffix in directory or file names
{"train": "my_train_dir/dataset.txt"},
{"train": "data/my_train_file.txt"},
Expand Down Expand Up @@ -615,8 +616,10 @@ def resolver(pattern):
[
# metadata files at the root
["metadata.jsonl"],
["metadata.csv"],
# nested metadata files
["data/metadata.jsonl", "data/train/metadata.jsonl"],
["data/metadata.csv", "data/train/metadata.csv"],
],
)
def test_get_metadata_files_patterns(metadata_files):
Expand Down

0 comments on commit 0e9b25f

Please sign in to comment.