diff --git a/docs/source/audio_load.mdx b/docs/source/audio_load.mdx
index 2af31c32e9d..44ff5a8d8d8 100644
--- a/docs/source/audio_load.mdx
+++ b/docs/source/audio_load.mdx
@@ -61,23 +61,26 @@ You can also load a dataset with an `AudioFolder` dataset builder. It does not r
## AudioFolder with metadata
-To link your audio files with metadata information, make sure your dataset has a `metadata.jsonl` file. Your dataset structure might look like:
+To link your audio files with metadata information, make sure your dataset has a `metadata.csv` file. Your dataset structure might look like:
```
-folder/train/metadata.jsonl
+folder/train/metadata.csv
folder/train/first_audio_file.mp3
folder/train/second_audio_file.mp3
folder/train/third_audio_file.mp3
```
-Your `metadata.jsonl` file must have a `file_name` column which links audio files with their metadata. An example `metadata.jsonl` file might look like:
+Your `metadata.csv` file must have a `file_name` column which links audio files with their metadata. An example `metadata.csv` file might look like:
-```python
-{"file_name": "first_audio_file.mp3", "transcription": "znowu się duch z ciałem zrośnie w młodocianej wstaniesz wiosnie i możesz skutkiem tych leków umierać wstawać wiek wieków dalej tam były przestrogi jak siekać głowę jak nogi"}
-{"file_name": "second_audio_file.mp3", "transcription": "już u źwierzyńca podwojów król zasiada przy nim książęta i panowie rada a gdzie wzniosły krążył ganek rycerze obok kochanek król skinął palcem zaczęto igrzysko"}
-{"file_name": "third_audio_file.mp3", "transcription": "pewnie kędyś w obłędzie ubite minęły szlaki zaczekajmy dzień jaki poślemy szukać wszędzie dziś jutro pewnie będzie posłali wszędzie sługi czekali dzień i drugi gdy nic nie doczekali z płaczem chcą jechać dali"}
+```text
+file_name,transcription
+first_audio_file.mp3,znowu się duch z ciałem zrośnie w młodocianej wstaniesz wiosnie i możesz skutkiem tych leków umierać wstawać wiek wieków dalej tam były przestrogi jak siekać głowę jak nogi
+second_audio_file.mp3,już u źwierzyńca podwojów król zasiada przy nim książęta i panowie rada a gdzie wzniosły krążył ganek rycerze obok kochanek król skinął palcem zaczęto igrzysko
+third_audio_file.mp3,pewnie kędyś w obłędzie ubite minęły szlaki zaczekajmy dzień jaki poślemy szukać wszędzie dziś jutro pewnie będzie posłali wszędzie sługi czekali dzień i drugi gdy nic nie doczekali z płaczem chcą jechać dali
```
+Metadata can also be specified as JSON Lines, in which case use `metadata.jsonl` as the name of the metadata file. This format is helpful when one of the columns has complex values, e.g. a list of floats, to avoid parsing errors or reading the values as strings.
+
Load your audio dataset by specifying `audiofolder` and the directory containing your data in `data_dir`:
```py
@@ -86,7 +89,7 @@ Load your audio dataset by specifying `audiofolder` and the directory containing
>>> dataset = load_dataset("audiofolder", data_dir="/path/to/folder")
```
-`AudioFolder` will load audio data and create a `transcription` column containing texts from `metadata.jsonl`:
+`AudioFolder` will load audio data and create a `transcription` column containing texts from `metadata.csv`:
```py
>>> dataset["train"][0]
@@ -146,7 +149,7 @@ If you have metadata files inside your data directory, but you still want to inf
-Alternatively, you can add `label` column to your `metadata.jsonl` file.
+Alternatively, you can add a `label` column to your `metadata.csv` file.
diff --git a/docs/source/image_load.mdx b/docs/source/image_load.mdx
index 7596a094bc8..aae1fdcc879 100644
--- a/docs/source/image_load.mdx
+++ b/docs/source/image_load.mdx
@@ -82,22 +82,25 @@ Load remote datasets from their URLs with the `data_files` parameter:
## ImageFolder with metadata
-Metadata associated with your dataset can also be loaded, extending the utility of `ImageFolder` to additional vision tasks like image captioning and object detection. Make sure your dataset has a `metadata.jsonl` file:
+Metadata associated with your dataset can also be loaded, extending the utility of `ImageFolder` to additional vision tasks like image captioning and object detection. Make sure your dataset has a `metadata.csv` file:
```
-folder/train/metadata.jsonl
+folder/train/metadata.csv
folder/train/0001.png
folder/train/0002.png
folder/train/0003.png
```
-Your `metadata.jsonl` file must have a `file_name` column which links image files with their metadata:
+Your `metadata.csv` file must have a `file_name` column which links image files with their metadata:
-```jsonl
-{"file_name": "0001.png", "additional_feature": "This is a first value of a text feature you added to your images"}
-{"file_name": "0002.png", "additional_feature": "This is a second value of a text feature you added to your images"}
-{"file_name": "0003.png", "additional_feature": "This is a third value of a text feature you added to your images"}
+```text
+file_name,additional_feature
+0001.png,This is a first value of a text feature you added to your images
+0002.png,This is a second value of a text feature you added to your images
+0003.png,This is a third value of a text feature you added to your images
```
+For complex value types, e.g. a list of floats, it may be more convenient to specify metadata as JSON Lines to avoid parsing errors or reading them as strings. In that case, use `metadata.jsonl` as the name of the metadata file.
+
If metadata files are present, the inferred labels based on the directory name are dropped by default. To include those labels, set `drop_labels=False` in `load_dataset`.
@@ -106,12 +109,13 @@ If metadata files are present, the inferred labels based on the directory name a
### Image captioning
-Image captioning datasets have text describing an image. An example `metadata.jsonl` may look like:
+Image captioning datasets have text describing an image. An example `metadata.csv` may look like:
-```jsonl
-{"file_name": "0001.png", "text": "This is a golden retriever playing with a ball"}
-{"file_name": "0002.png", "text": "A german shepherd"}
-{"file_name": "0003.png", "text": "One chihuahua"}
+```text
+file_name,text
+0001.png,This is a golden retriever playing with a ball
+0002.png,A german shepherd
+0003.png,One chihuahua
```
Load the dataset with `ImageFolder`, and it will create a `text` column for the image captions:
diff --git a/src/datasets/data_files.py b/src/datasets/data_files.py
index dd2098dfdc1..942184d683d 100644
--- a/src/datasets/data_files.py
+++ b/src/datasets/data_files.py
@@ -79,7 +79,12 @@ class Url(str):
DEFAULT_PATTERNS_SPLIT_IN_DIR_NAME,
DEFAULT_PATTERNS_ALL,
]
-METADATA_PATTERNS = ["metadata.jsonl", "**/metadata.jsonl"] # metadata file for ImageFolder and AudioFolder
+METADATA_PATTERNS = [
+ "metadata.csv",
+ "**/metadata.csv",
+ "metadata.jsonl",
+ "**/metadata.jsonl",
+] # metadata file for ImageFolder and AudioFolder
WILDCARD_CHARACTERS = "*[]"
FILES_TO_IGNORE = ["README.md", "config.json", "dataset_infos.json", "dummy_data.zip", "dataset_dict.json"]
diff --git a/src/datasets/packaged_modules/folder_based_builder/folder_based_builder.py b/src/datasets/packaged_modules/folder_based_builder/folder_based_builder.py
index 379019ae083..f7a09d7c05d 100644
--- a/src/datasets/packaged_modules/folder_based_builder/folder_based_builder.py
+++ b/src/datasets/packaged_modules/folder_based_builder/folder_based_builder.py
@@ -4,6 +4,8 @@
from dataclasses import dataclass
from typing import Any, List, Optional, Tuple
+import pandas as pd
+import pyarrow as pa
import pyarrow.compute as pc
import pyarrow.json as paj
@@ -68,7 +70,7 @@ class FolderBasedBuilder(datasets.GeneratorBasedBuilder):
EXTENSIONS: List[str]
SKIP_CHECKSUM_COMPUTATION_BY_DEFAULT: bool = True
- METADATA_FILENAME: str = "metadata.jsonl"
+ METADATA_FILENAMES: List[str] = ["metadata.csv", "metadata.jsonl"]
def _info(self):
return datasets.DatasetInfo(features=self.config.features)
@@ -97,12 +99,12 @@ def analyze(files_or_archives, downloaded_files_or_dirs, split):
if original_file_ext.lower() in self.EXTENSIONS:
if not self.config.drop_labels:
labels.add(os.path.basename(os.path.dirname(original_file)))
- elif os.path.basename(original_file) == self.METADATA_FILENAME:
+ elif os.path.basename(original_file) in self.METADATA_FILENAMES:
metadata_files[split].add((original_file, downloaded_file))
else:
original_file_name = os.path.basename(original_file)
logger.debug(
- f"The file '{original_file_name}' was ignored: it is not an {self.BASE_COLUMN_NAME}, and is not {self.METADATA_FILENAME} either."
+                            f"The file '{original_file_name}' was ignored: it is not an {self.BASE_COLUMN_NAME}, and is not {self.METADATA_FILENAMES} either."
)
else:
archives, downloaded_dirs = files_or_archives, downloaded_files_or_dirs
@@ -113,13 +115,13 @@ def analyze(files_or_archives, downloaded_files_or_dirs, split):
if downloaded_dir_file_ext in self.EXTENSIONS:
if not self.config.drop_labels:
labels.add(os.path.basename(os.path.dirname(downloaded_dir_file)))
- elif os.path.basename(downloaded_dir_file) == self.METADATA_FILENAME:
+ elif os.path.basename(downloaded_dir_file) in self.METADATA_FILENAMES:
metadata_files[split].add((None, downloaded_dir_file))
else:
archive_file_name = os.path.basename(archive)
original_file_name = os.path.basename(downloaded_dir_file)
logger.debug(
- f"The file '{original_file_name}' from the archive '{archive_file_name}' was ignored: it is not an {self.BASE_COLUMN_NAME}, and is not {self.METADATA_FILENAME} either."
+ f"The file '{original_file_name}' from the archive '{archive_file_name}' was ignored: it is not an {self.BASE_COLUMN_NAME}, and is not {self.METADATA_FILENAMES} either."
)
data_files = self.config.data_files
@@ -173,9 +175,18 @@ def analyze(files_or_archives, downloaded_files_or_dirs, split):
# * all metadata files have the same set of features
# * the `file_name` key is one of the metadata keys and is of type string
features_per_metadata_file: List[Tuple[str, datasets.Features]] = []
+
+ # Check that all metadata files share the same format
+ metadata_ext = set(
+ os.path.splitext(downloaded_metadata_file)[1][1:]
+ for _, downloaded_metadata_file in itertools.chain.from_iterable(metadata_files.values())
+ )
+ if len(metadata_ext) > 1:
+ raise ValueError(f"Found metadata files with different extensions: {list(metadata_ext)}")
+ metadata_ext = metadata_ext.pop()
+
for _, downloaded_metadata_file in itertools.chain.from_iterable(metadata_files.values()):
- with open(downloaded_metadata_file, "rb") as f:
- pa_metadata_table = paj.read_json(f)
+ pa_metadata_table = self._read_metadata(downloaded_metadata_file)
features_per_metadata_file.append(
(downloaded_metadata_file, datasets.Features.from_arrow_schema(pa_metadata_table.schema))
)
@@ -232,12 +243,21 @@ def _split_files_and_archives(self, data_files):
_, data_file_ext = os.path.splitext(data_file)
if data_file_ext.lower() in self.EXTENSIONS:
files.append(data_file)
- elif os.path.basename(data_file) == self.METADATA_FILENAME:
+ elif os.path.basename(data_file) in self.METADATA_FILENAMES:
files.append(data_file)
else:
archives.append(data_file)
return files, archives
+ def _read_metadata(self, metadata_file):
+ metadata_file_ext = os.path.splitext(metadata_file)[1][1:]
+ if metadata_file_ext == "csv":
+ # Use `pd.read_csv` (although slower) instead of `pyarrow.csv.read_csv` for reading CSV files for consistency with the CSV packaged module
+ return pa.Table.from_pandas(pd.read_csv(metadata_file))
+ else:
+ with open(metadata_file, "rb") as f:
+ return paj.read_json(f)
+
def _generate_examples(self, files, metadata_files, split_name, add_metadata, add_labels):
split_metadata_files = metadata_files.get(split_name, [])
sample_empty_metadata = (
@@ -248,6 +268,13 @@ def _generate_examples(self, files, metadata_files, split_name, add_metadata, ad
metadata_dict = None
downloaded_metadata_file = None
+ if split_metadata_files:
+ metadata_ext = set(
+ os.path.splitext(downloaded_metadata_file)[1][1:]
+ for _, downloaded_metadata_file in split_metadata_files
+ )
+ metadata_ext = metadata_ext.pop()
+
file_idx = 0
for original_file, downloaded_file_or_dir in files:
if original_file is not None:
@@ -276,8 +303,7 @@ def _generate_examples(self, files, metadata_files, split_name, add_metadata, ad
_, metadata_file, downloaded_metadata_file = min(
metadata_file_candidates, key=lambda x: count_path_segments(x[0])
)
- with open(downloaded_metadata_file, "rb") as f:
- pa_metadata_table = paj.read_json(f)
+ pa_metadata_table = self._read_metadata(downloaded_metadata_file)
pa_file_name_array = pa_metadata_table["file_name"]
pa_file_name_array = pc.replace_substring(
pa_file_name_array, pattern="\\", replacement="/"
@@ -292,7 +318,7 @@ def _generate_examples(self, files, metadata_files, split_name, add_metadata, ad
}
else:
raise ValueError(
- f"One or several metadata.jsonl were found, but not in the same directory or in a parent directory of {downloaded_file_or_dir}."
+ f"One or several metadata.{metadata_ext} were found, but not in the same directory or in a parent directory of {downloaded_file_or_dir}."
)
if metadata_dir is not None and downloaded_metadata_file is not None:
file_relpath = os.path.relpath(original_file, metadata_dir)
@@ -304,7 +330,7 @@ def _generate_examples(self, files, metadata_files, split_name, add_metadata, ad
sample_metadata = metadata_dict[file_relpath]
else:
raise ValueError(
- f"One or several metadata.jsonl were found, but not in the same directory or in a parent directory of {downloaded_file_or_dir}."
+ f"One or several metadata.{metadata_ext} were found, but not in the same directory or in a parent directory of {downloaded_file_or_dir}."
)
else:
sample_metadata = {}
@@ -346,8 +372,7 @@ def _generate_examples(self, files, metadata_files, split_name, add_metadata, ad
_, metadata_file, downloaded_metadata_file = min(
metadata_file_candidates, key=lambda x: count_path_segments(x[0])
)
- with open(downloaded_metadata_file, "rb") as f:
- pa_metadata_table = paj.read_json(f)
+ pa_metadata_table = self._read_metadata(downloaded_metadata_file)
pa_file_name_array = pa_metadata_table["file_name"]
pa_file_name_array = pc.replace_substring(
pa_file_name_array, pattern="\\", replacement="/"
@@ -362,7 +387,7 @@ def _generate_examples(self, files, metadata_files, split_name, add_metadata, ad
}
else:
raise ValueError(
- f"One or several metadata.jsonl were found, but not in the same directory or in a parent directory of {downloaded_dir_file}."
+ f"One or several metadata.{metadata_ext} were found, but not in the same directory or in a parent directory of {downloaded_dir_file}."
)
if metadata_dir is not None and downloaded_metadata_file is not None:
downloaded_dir_file_relpath = os.path.relpath(downloaded_dir_file, metadata_dir)
@@ -374,7 +399,7 @@ def _generate_examples(self, files, metadata_files, split_name, add_metadata, ad
sample_metadata = metadata_dict[downloaded_dir_file_relpath]
else:
raise ValueError(
- f"One or several metadata.jsonl were found, but not in the same directory or in a parent directory of {downloaded_dir_file}."
+ f"One or several metadata.{metadata_ext} were found, but not in the same directory or in a parent directory of {downloaded_dir_file}."
)
else:
sample_metadata = {}
diff --git a/tests/packaged_modules/test_audiofolder.py b/tests/packaged_modules/test_audiofolder.py
index 25a0f8141fe..9141251beb9 100644
--- a/tests/packaged_modules/test_audiofolder.py
+++ b/tests/packaged_modules/test_audiofolder.py
@@ -132,8 +132,8 @@ def data_files_with_one_split_and_metadata(tmp_path, audio_file):
return data_files_with_one_split_and_metadata
-@pytest.fixture
-def data_files_with_two_splits_and_metadata(tmp_path, audio_file):
+@pytest.fixture(params=["jsonl", "csv"])
+def data_files_with_two_splits_and_metadata(request, tmp_path, audio_file):
data_dir = tmp_path / "audiofolder_data_dir_with_metadata"
data_dir.mkdir(parents=True, exist_ok=True)
train_dir = data_dir / "train"
@@ -148,20 +148,39 @@ def data_files_with_two_splits_and_metadata(tmp_path, audio_file):
audio_filename3 = test_dir / "audio_file3.wav" # test audio
shutil.copyfile(audio_file, audio_filename3)
- train_audio_metadata_filename = train_dir / "metadata.jsonl"
- audio_metadata = textwrap.dedent(
- """\
+ train_audio_metadata_filename = train_dir / f"metadata.{request.param}"
+ audio_metadata = (
+ textwrap.dedent(
+ """\
{"file_name": "audio_file.wav", "text": "First train audio transcription"}
{"file_name": "audio_file2.wav", "text": "Second train audio transcription"}
"""
+ )
+ if request.param == "jsonl"
+ else textwrap.dedent(
+ """\
+ file_name,text
+ audio_file.wav,First train audio transcription
+ audio_file2.wav,Second train audio transcription
+ """
+ )
)
with open(train_audio_metadata_filename, "w", encoding="utf-8") as f:
f.write(audio_metadata)
- test_audio_metadata_filename = test_dir / "metadata.jsonl"
- audio_metadata = textwrap.dedent(
- """\
+ test_audio_metadata_filename = test_dir / f"metadata.{request.param}"
+ audio_metadata = (
+ textwrap.dedent(
+ """\
{"file_name": "audio_file3.wav", "text": "Test audio transcription"}
"""
+ )
+ if request.param == "jsonl"
+ else textwrap.dedent(
+ """\
+ file_name,text
+ audio_file3.wav,Test audio transcription
+ """
+ )
)
with open(test_audio_metadata_filename, "w", encoding="utf-8") as f:
f.write(audio_metadata)
@@ -357,11 +376,26 @@ def test_generate_examples_with_metadata_that_misses_one_audio(
@require_sndfile
@pytest.mark.parametrize("streaming", [False, True])
-@pytest.mark.parametrize("n_splits", [1, 2])
-def test_data_files_with_metadata_and_splits(
- streaming, cache_dir, n_splits, data_files_with_one_split_and_metadata, data_files_with_two_splits_and_metadata
-):
- data_files = data_files_with_one_split_and_metadata if n_splits == 1 else data_files_with_two_splits_and_metadata
+def test_data_files_with_metadata_and_single_split(streaming, cache_dir, data_files_with_one_split_and_metadata):
+ data_files = data_files_with_one_split_and_metadata
+ audiofolder = AudioFolder(data_files=data_files, cache_dir=cache_dir)
+ audiofolder.download_and_prepare()
+ datasets = audiofolder.as_streaming_dataset() if streaming else audiofolder.as_dataset()
+ for split, data_files in data_files.items():
+ expected_num_of_audios = len(data_files) - 1 # don't count the metadata file
+ assert split in datasets
+ dataset = list(datasets[split])
+ assert len(dataset) == expected_num_of_audios
+ # make sure each sample has its own audio and metadata
+ assert len(set(example["audio"]["path"] for example in dataset)) == expected_num_of_audios
+ assert len(set(example["text"] for example in dataset)) == expected_num_of_audios
+ assert all(example["text"] is not None for example in dataset)
+
+
+@require_sndfile
+@pytest.mark.parametrize("streaming", [False, True])
+def test_data_files_with_metadata_and_multiple_splits(streaming, cache_dir, data_files_with_two_splits_and_metadata):
+ data_files = data_files_with_two_splits_and_metadata
audiofolder = AudioFolder(data_files=data_files, cache_dir=cache_dir)
audiofolder.download_and_prepare()
datasets = audiofolder.as_streaming_dataset() if streaming else audiofolder.as_dataset()
@@ -442,3 +476,33 @@ def test_data_files_with_wrong_audio_file_name_column_in_metadata_file(cache_dir
with pytest.raises(ValueError) as exc_info:
audiofolder.download_and_prepare()
assert "`file_name` must be present" in str(exc_info.value)
+
+
+@require_sndfile
+def test_data_files_with_metadata_in_different_formats(cache_dir, tmp_path, audio_file):
+ data_dir = tmp_path / "data_dir_with_metadata_in_different_format"
+ data_dir.mkdir(parents=True, exist_ok=True)
+ shutil.copyfile(audio_file, data_dir / "audio_file.wav")
+ audio_metadata_filename_jsonl = data_dir / "metadata.jsonl"
+ audio_metadata_jsonl = textwrap.dedent(
+ """\
+ {"file_name": "audio_file.wav", "text": "Audio transcription"}
+ """
+ )
+ with open(audio_metadata_filename_jsonl, "w", encoding="utf-8") as f:
+ f.write(audio_metadata_jsonl)
+ audio_metadata_filename_csv = data_dir / "metadata.csv"
+ audio_metadata_csv = textwrap.dedent(
+ """\
+ file_name,text
+ audio_file.wav,Audio transcription
+ """
+ )
+ with open(audio_metadata_filename_csv, "w", encoding="utf-8") as f:
+ f.write(audio_metadata_csv)
+
+ data_files_with_bad_metadata = DataFilesDict.from_local_or_remote(get_data_patterns_locally(data_dir), data_dir)
+ audiofolder = AudioFolder(data_files=data_files_with_bad_metadata, cache_dir=cache_dir)
+ with pytest.raises(ValueError) as exc_info:
+ audiofolder.download_and_prepare()
+ assert "metadata files with different extensions" in str(exc_info.value)
diff --git a/tests/packaged_modules/test_imagefolder.py b/tests/packaged_modules/test_imagefolder.py
index 1c1c4b6752a..71f21702f2b 100644
--- a/tests/packaged_modules/test_imagefolder.py
+++ b/tests/packaged_modules/test_imagefolder.py
@@ -98,8 +98,8 @@ def image_files_with_metadata_that_misses_one_image(tmp_path, image_file):
return str(image_filename), str(image_filename2), str(image_metadata_filename)
-@pytest.fixture
-def data_files_with_one_split_and_metadata(tmp_path, image_file):
+@pytest.fixture(params=["jsonl", "csv"])
+def data_files_with_one_split_and_metadata(request, tmp_path, image_file):
data_dir = tmp_path / "imagefolder_data_dir_with_metadata_one_split"
data_dir.mkdir(parents=True, exist_ok=True)
subdir = data_dir / "subdir"
@@ -112,13 +112,24 @@ def data_files_with_one_split_and_metadata(tmp_path, image_file):
image_filename3 = subdir / "image_rgb3.jpg" # in subdir
shutil.copyfile(image_file, image_filename3)
- image_metadata_filename = data_dir / "metadata.jsonl"
- image_metadata = textwrap.dedent(
- """\
+ image_metadata_filename = data_dir / f"metadata.{request.param}"
+ image_metadata = (
+ textwrap.dedent(
+ """\
{"file_name": "image_rgb.jpg", "caption": "Nice image"}
{"file_name": "image_rgb2.jpg", "caption": "Nice second image"}
{"file_name": "subdir/image_rgb3.jpg", "caption": "Nice third image"}
"""
+ )
+ if request.param == "jsonl"
+ else textwrap.dedent(
+ """\
+ file_name,caption
+ image_rgb.jpg,Nice image
+ image_rgb2.jpg,Nice second image
+ subdir/image_rgb3.jpg,Nice third image
+ """
+ )
)
with open(image_metadata_filename, "w", encoding="utf-8") as f:
f.write(image_metadata)
@@ -130,8 +141,8 @@ def data_files_with_one_split_and_metadata(tmp_path, image_file):
return data_files_with_one_split_and_metadata
-@pytest.fixture
-def data_files_with_two_splits_and_metadata(tmp_path, image_file):
+@pytest.fixture(params=["jsonl", "csv"])
+def data_files_with_two_splits_and_metadata(request, tmp_path, image_file):
data_dir = tmp_path / "imagefolder_data_dir_with_metadata_two_splits"
data_dir.mkdir(parents=True, exist_ok=True)
train_dir = data_dir / "train"
@@ -146,20 +157,39 @@ def data_files_with_two_splits_and_metadata(tmp_path, image_file):
image_filename3 = test_dir / "image_rgb3.jpg" # test image
shutil.copyfile(image_file, image_filename3)
- train_image_metadata_filename = train_dir / "metadata.jsonl"
- image_metadata = textwrap.dedent(
- """\
+ train_image_metadata_filename = train_dir / f"metadata.{request.param}"
+ image_metadata = (
+ textwrap.dedent(
+ """\
{"file_name": "image_rgb.jpg", "caption": "Nice train image"}
{"file_name": "image_rgb2.jpg", "caption": "Nice second train image"}
"""
+ )
+ if request.param == "jsonl"
+ else textwrap.dedent(
+ """\
+ file_name,caption
+ image_rgb.jpg,Nice train image
+ image_rgb2.jpg,Nice second train image
+ """
+ )
)
with open(train_image_metadata_filename, "w", encoding="utf-8") as f:
f.write(image_metadata)
- test_image_metadata_filename = test_dir / "metadata.jsonl"
- image_metadata = textwrap.dedent(
- """\
+ test_image_metadata_filename = test_dir / f"metadata.{request.param}"
+ image_metadata = (
+ textwrap.dedent(
+ """\
{"file_name": "image_rgb3.jpg", "caption": "Nice test image"}
"""
+ )
+ if request.param == "jsonl"
+ else textwrap.dedent(
+ """\
+ file_name,caption
+ image_rgb3.jpg,Nice test image
+ """
+ )
)
with open(test_image_metadata_filename, "w", encoding="utf-8") as f:
f.write(image_metadata)
@@ -353,11 +383,26 @@ def test_generate_examples_with_metadata_that_misses_one_image(
@require_pil
@pytest.mark.parametrize("streaming", [False, True])
-@pytest.mark.parametrize("n_splits", [1, 2])
-def test_data_files_with_metadata_and_splits(
- streaming, cache_dir, n_splits, data_files_with_one_split_and_metadata, data_files_with_two_splits_and_metadata
-):
- data_files = data_files_with_one_split_and_metadata if n_splits == 1 else data_files_with_two_splits_and_metadata
+def test_data_files_with_metadata_and_single_split(streaming, cache_dir, data_files_with_one_split_and_metadata):
+ data_files = data_files_with_one_split_and_metadata
+ imagefolder = ImageFolder(data_files=data_files, cache_dir=cache_dir)
+ imagefolder.download_and_prepare()
+ datasets = imagefolder.as_streaming_dataset() if streaming else imagefolder.as_dataset()
+ for split, data_files in data_files.items():
+ expected_num_of_images = len(data_files) - 1 # don't count the metadata file
+ assert split in datasets
+ dataset = list(datasets[split])
+ assert len(dataset) == expected_num_of_images
+ # make sure each sample has its own image and metadata
+ assert len(set(example["image"].filename for example in dataset)) == expected_num_of_images
+ assert len(set(example["caption"] for example in dataset)) == expected_num_of_images
+ assert all(example["caption"] is not None for example in dataset)
+
+
+@require_pil
+@pytest.mark.parametrize("streaming", [False, True])
+def test_data_files_with_metadata_and_multiple_splits(streaming, cache_dir, data_files_with_two_splits_and_metadata):
+ data_files = data_files_with_two_splits_and_metadata
imagefolder = ImageFolder(data_files=data_files, cache_dir=cache_dir)
imagefolder.download_and_prepare()
datasets = imagefolder.as_streaming_dataset() if streaming else imagefolder.as_dataset()
@@ -431,3 +476,33 @@ def test_data_files_with_wrong_image_file_name_column_in_metadata_file(cache_dir
with pytest.raises(ValueError) as exc_info:
imagefolder.download_and_prepare()
assert "`file_name` must be present" in str(exc_info.value)
+
+
+@require_pil
+def test_data_files_with_metadata_in_different_formats(cache_dir, tmp_path, image_file):
+ data_dir = tmp_path / "data_dir_with_metadata_in_different_format"
+ data_dir.mkdir(parents=True, exist_ok=True)
+ shutil.copyfile(image_file, data_dir / "image_rgb.jpg")
+ image_metadata_filename_jsonl = data_dir / "metadata.jsonl"
+ image_metadata_jsonl = textwrap.dedent(
+ """\
+ {"file_name": "image_rgb.jpg", "caption": "Nice image"}
+ """
+ )
+ with open(image_metadata_filename_jsonl, "w", encoding="utf-8") as f:
+ f.write(image_metadata_jsonl)
+ image_metadata_filename_csv = data_dir / "metadata.csv"
+ image_metadata_csv = textwrap.dedent(
+ """\
+ file_name,caption
+ image_rgb.jpg,Nice image
+ """
+ )
+ with open(image_metadata_filename_csv, "w", encoding="utf-8") as f:
+ f.write(image_metadata_csv)
+
+ data_files_with_bad_metadata = DataFilesDict.from_local_or_remote(get_data_patterns_locally(data_dir), data_dir)
+ imagefolder = ImageFolder(data_files=data_files_with_bad_metadata, cache_dir=cache_dir)
+ with pytest.raises(ValueError) as exc_info:
+ imagefolder.download_and_prepare()
+ assert "metadata files with different extensions" in str(exc_info.value)
diff --git a/tests/test_data_files.py b/tests/test_data_files.py
index 62082a3ff6f..1679f7ee169 100644
--- a/tests/test_data_files.py
+++ b/tests/test_data_files.py
@@ -566,6 +566,7 @@ def ls(self, path, detail=True, refresh=True, **kwargs):
{"train": "dataset.txt"},
{"train": "data/dataset.txt"},
{"train": ["data/image.jpg", "metadata.jsonl"]},
+ {"train": ["data/image.jpg", "metadata.csv"]},
# With prefix or suffix in directory or file names
{"train": "my_train_dir/dataset.txt"},
{"train": "data/my_train_file.txt"},
@@ -615,8 +616,10 @@ def resolver(pattern):
[
# metadata files at the root
["metadata.jsonl"],
+ ["metadata.csv"],
# nested metadata files
["data/metadata.jsonl", "data/train/metadata.jsonl"],
+ ["data/metadata.csv", "data/train/metadata.csv"],
],
)
def test_get_metadata_files_patterns(metadata_files):