From 583bb7d6a42f48e9a959023f9cdbfa3402b09027 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mario=20=C5=A0a=C5=A1ko?= Date: Mon, 3 Oct 2022 14:21:31 +0200 Subject: [PATCH] Revert task removal in folder-based builders (#5051) * Add AudioClassification task * Add classification task to folder based builders * Fix tests * Minor fix * Minor fix again --- .../package_reference/task_templates.mdx | 2 ++ .../audiofolder/audiofolder.py | 2 ++ .../folder_based_builder.py | 9 +++-- .../imagefolder/imagefolder.py | 2 ++ src/datasets/tasks/__init__.py | 3 ++ src/datasets/tasks/audio_classificiation.py | 33 +++++++++++++++++++ .../test_folder_based_builder.py | 2 ++ tests/test_tasks.py | 28 ++++++++++++++++ 8 files changed, 79 insertions(+), 2 deletions(-) create mode 100644 src/datasets/tasks/audio_classificiation.py diff --git a/docs/source/package_reference/task_templates.mdx b/docs/source/package_reference/task_templates.mdx index ec8bcae0217..52d275b8531 100644 --- a/docs/source/package_reference/task_templates.mdx +++ b/docs/source/package_reference/task_templates.mdx @@ -4,6 +4,8 @@ The tasks supported by [`Dataset.prepare_for_task`] and [`DatasetDict.prepare_fo [[autodoc]] datasets.tasks.AutomaticSpeechRecognition +[[autodoc]] datasets.tasks.AudioClassification + [[autodoc]] datasets.tasks.ImageClassification - align_with_features diff --git a/src/datasets/packaged_modules/audiofolder/audiofolder.py b/src/datasets/packaged_modules/audiofolder/audiofolder.py index ab90d8de378..e408804aa46 100644 --- a/src/datasets/packaged_modules/audiofolder/audiofolder.py +++ b/src/datasets/packaged_modules/audiofolder/audiofolder.py @@ -1,6 +1,7 @@ from typing import List import datasets +from datasets.tasks import AudioClassification from ..folder_based_builder import folder_based_builder @@ -20,6 +21,7 @@ class AudioFolder(folder_based_builder.FolderBasedBuilder): BASE_COLUMN_NAME = "audio" BUILDER_CONFIG_CLASS = AudioFolderConfig EXTENSIONS: List[str] # definition at the bottom of the script + CLASSIFICATION_TASK = AudioClassification(audio_column="audio", label_column="label") # Obtained with: diff --git a/src/datasets/packaged_modules/folder_based_builder/folder_based_builder.py b/src/datasets/packaged_modules/folder_based_builder/folder_based_builder.py index f7a09d7c05d..6c96d402aa1 100644 --- a/src/datasets/packaged_modules/folder_based_builder/folder_based_builder.py +++ b/src/datasets/packaged_modules/folder_based_builder/folder_based_builder.py @@ -2,7 +2,7 @@ import itertools import os from dataclasses import dataclass -from typing import Any, List, Optional, Tuple +from typing import List, Optional, Tuple import pandas as pd import pyarrow as pa @@ -10,6 +10,8 @@ import pyarrow.json as paj import datasets +from datasets.features.features import FeatureType +from datasets.tasks.base import TaskTemplate logger = datasets.utils.logging.get_logger(__name__) @@ -62,12 +64,14 @@ class FolderBasedBuilder(datasets.GeneratorBasedBuilder): BUILDER_CONFIG_CLASS: builder config inherited from `folder_based_builder.FolderBasedBuilderConfig` EXTENSIONS: list of allowed extensions (only files with these extensions and METADATA_FILENAME files will be included in a dataset) + CLASSIFICATION_TASK: classification task to use if labels are obtained from the folder structure """ - BASE_FEATURE: Any + BASE_FEATURE: FeatureType BASE_COLUMN_NAME: str BUILDER_CONFIG_CLASS: FolderBasedBuilderConfig EXTENSIONS: List[str] + CLASSIFICATION_TASK: TaskTemplate SKIP_CHECKSUM_COMPUTATION_BY_DEFAULT: bool = True METADATA_FILENAMES: List[str] = ["metadata.csv", "metadata.jsonl"] @@ -214,6 +218,7 @@ def analyze(files_or_archives, downloaded_files_or_dirs, split): "label": datasets.ClassLabel(names=sorted(labels)), } ) + self.info.task_templates = [self.CLASSIFICATION_TASK.align_with_features(self.info.features)] else: self.info.features = datasets.Features({self.BASE_COLUMN_NAME: self.BASE_FEATURE}) diff --git a/src/datasets/packaged_modules/imagefolder/imagefolder.py b/src/datasets/packaged_modules/imagefolder/imagefolder.py index d4e25866750..80a5051210d 100644 --- a/src/datasets/packaged_modules/imagefolder/imagefolder.py +++ b/src/datasets/packaged_modules/imagefolder/imagefolder.py @@ -1,6 +1,7 @@ from typing import List import datasets +from datasets.tasks import ImageClassification from ..folder_based_builder import folder_based_builder @@ -20,6 +21,7 @@ class ImageFolder(folder_based_builder.FolderBasedBuilder): BASE_COLUMN_NAME = "image" BUILDER_CONFIG_CLASS = ImageFolderConfig EXTENSIONS: List[str] # definition at the bottom of the script + CLASSIFICATION_TASK = ImageClassification(image_column="image", label_column="label") # Obtained with: diff --git a/src/datasets/tasks/__init__.py b/src/datasets/tasks/__init__.py index 7ff448b72e4..f84db2b79c8 100644 --- a/src/datasets/tasks/__init__.py +++ b/src/datasets/tasks/__init__.py @@ -1,6 +1,7 @@ from typing import Optional from ..utils.logging import get_logger +from .audio_classificiation import AudioClassification from .automatic_speech_recognition import AutomaticSpeechRecognition from .base import TaskTemplate from .image_classification import ImageClassification @@ -12,6 +13,7 @@ __all__ = [ "AutomaticSpeechRecognition", + "AudioClassification", "ImageClassification", "LanguageModeling", "QuestionAnsweringExtractive", @@ -25,6 +27,7 @@ NAME2TEMPLATE = { AutomaticSpeechRecognition.task: AutomaticSpeechRecognition, + AudioClassification.task: AudioClassification, ImageClassification.task: ImageClassification, LanguageModeling.task: LanguageModeling, QuestionAnsweringExtractive.task: QuestionAnsweringExtractive, diff --git a/src/datasets/tasks/audio_classificiation.py b/src/datasets/tasks/audio_classificiation.py new file mode 100644 index 00000000000..e0a2370d03d --- /dev/null +++ b/src/datasets/tasks/audio_classificiation.py @@ -0,0 +1,33 @@ +import copy +from dataclasses import dataclass +from typing import ClassVar, Dict + +from ..features import Audio, ClassLabel, Features +from .base import TaskTemplate + + +@dataclass(frozen=True) +class AudioClassification(TaskTemplate): + task: str = "audio-classification" + input_schema: ClassVar[Features] = Features({"audio": Audio()}) + label_schema: ClassVar[Features] = Features({"labels": ClassLabel}) + audio_column: str = "audio" + label_column: str = "labels" + + def align_with_features(self, features): + if self.label_column not in features: + raise ValueError(f"Column {self.label_column} is not present in features.") + if not isinstance(features[self.label_column], ClassLabel): + raise ValueError(f"Column {self.label_column} is not a ClassLabel.") + task_template = copy.deepcopy(self) + label_schema = self.label_schema.copy() + label_schema["labels"] = features[self.label_column] + task_template.__dict__["label_schema"] = label_schema + return task_template + + @property + def column_mapping(self) -> Dict[str, str]: + return { + self.audio_column: "audio", + self.label_column: "labels", + } diff --git a/tests/packaged_modules/test_folder_based_builder.py b/tests/packaged_modules/test_folder_based_builder.py index 51b5d46d8f9..34fe3a62db7 100644 --- a/tests/packaged_modules/test_folder_based_builder.py +++ b/tests/packaged_modules/test_folder_based_builder.py @@ -11,6 +11,7 @@ FolderBasedBuilder, FolderBasedBuilderConfig, ) +from datasets.tasks import TextClassification class DummyFolderBasedBuilder(FolderBasedBuilder): @@ -18,6 +19,7 @@ class DummyFolderBasedBuilder(FolderBasedBuilder): BASE_COLUMN_NAME = "base" BUILDER_CONFIG_CLASS = FolderBasedBuilderConfig EXTENSIONS = [".txt"] + CLASSIFICATION_TASK = TextClassification(text_column="base", label_column="label") @pytest.fixture diff --git a/tests/test_tasks.py b/tests/test_tasks.py index ce6bf41e20c..a2877ce9185 100644 --- a/tests/test_tasks.py +++ b/tests/test_tasks.py @@ -5,6 +5,7 @@ from datasets.features import Audio, ClassLabel, Features, Image, Sequence, Value from datasets.info import DatasetInfo from datasets.tasks import ( + AudioClassification, AutomaticSpeechRecognition, ImageClassification, LanguageModeling, @@ -126,6 +127,33 @@ def test_from_dict(self): self.assertEqual(label_schema, task.label_schema) +class AudioClassificationTest(TestCase): + def setUp(self): + self.labels = sorted(["pos", "neg"]) + + def test_column_mapping(self): + task = AudioClassification(audio_column="input_audio", label_column="input_label") + self.assertDictEqual({"input_audio": "audio", "input_label": "labels"}, task.column_mapping) + + def test_from_dict(self): + input_schema = Features({"audio": Audio()}) + label_schema = Features({"labels": ClassLabel}) + template_dict = { + "audio_column": "input_image", + "label_column": "input_label", + } + task = AudioClassification.from_dict(template_dict) + self.assertEqual("audio-classification", task.task) + self.assertEqual(input_schema, task.input_schema) + self.assertEqual(label_schema, task.label_schema) + + def test_align_with_features(self): + task = AudioClassification(audio_column="input_audio", label_column="input_label") + self.assertEqual(task.label_schema["labels"], ClassLabel) + task = task.align_with_features(Features({"input_label": ClassLabel(names=self.labels)})) + self.assertEqual(task.label_schema["labels"], ClassLabel(names=self.labels)) + + class ImageClassificationTest(TestCase): def setUp(self): self.labels = sorted(["pos", "neg"])