diff --git a/src/datasets/tasks/audio_classificiation.py b/src/datasets/tasks/audio_classificiation.py index e0a2370d03d..6f9fe402f38 100644 --- a/src/datasets/tasks/audio_classificiation.py +++ b/src/datasets/tasks/audio_classificiation.py @@ -1,5 +1,5 @@ import copy -from dataclasses import dataclass +from dataclasses import dataclass, field from typing import ClassVar, Dict from ..features import Audio, ClassLabel, Features @@ -8,7 +8,7 @@ @dataclass(frozen=True) class AudioClassification(TaskTemplate): - task: str = "audio-classification" + task: str = field(default="audio-classification", metadata={"include_in_asdict_even_if_is_default": True}) input_schema: ClassVar[Features] = Features({"audio": Audio()}) label_schema: ClassVar[Features] = Features({"labels": ClassLabel}) audio_column: str = "audio" diff --git a/src/datasets/tasks/automatic_speech_recognition.py b/src/datasets/tasks/automatic_speech_recognition.py index cccd8020e69..103a98a1bc9 100644 --- a/src/datasets/tasks/automatic_speech_recognition.py +++ b/src/datasets/tasks/automatic_speech_recognition.py @@ -1,5 +1,5 @@ import copy -from dataclasses import dataclass +from dataclasses import dataclass, field from typing import ClassVar, Dict from ..features import Audio, Features, Value @@ -8,7 +8,7 @@ @dataclass(frozen=True) class AutomaticSpeechRecognition(TaskTemplate): - task: str = "automatic-speech-recognition" + task: str = field(default="automatic-speech-recognition", metadata={"include_in_asdict_even_if_is_default": True}) input_schema: ClassVar[Features] = Features({"audio": Audio()}) label_schema: ClassVar[Features] = Features({"transcription": Value("string")}) audio_column: str = "audio" diff --git a/src/datasets/tasks/image_classification.py b/src/datasets/tasks/image_classification.py index 251daaf6634..20a19e0408a 100644 --- a/src/datasets/tasks/image_classification.py +++ b/src/datasets/tasks/image_classification.py @@ -1,5 +1,5 @@ import copy -from dataclasses import dataclass +from dataclasses import dataclass, field from typing import ClassVar, Dict from ..features import ClassLabel, Features, Image @@ -8,7 +8,7 @@ @dataclass(frozen=True) class ImageClassification(TaskTemplate): - task: str = "image-classification" + task: str = field(default="image-classification", metadata={"include_in_asdict_even_if_is_default": True}) input_schema: ClassVar[Features] = Features({"image": Image()}) label_schema: ClassVar[Features] = Features({"labels": ClassLabel}) image_column: str = "image" diff --git a/src/datasets/tasks/language_modeling.py b/src/datasets/tasks/language_modeling.py index 4f0dfedd149..b2837744fa1 100644 --- a/src/datasets/tasks/language_modeling.py +++ b/src/datasets/tasks/language_modeling.py @@ -1,4 +1,4 @@ -from dataclasses import dataclass +from dataclasses import dataclass, field from typing import ClassVar, Dict from ..features import Features, Value @@ -7,7 +7,7 @@ @dataclass(frozen=True) class LanguageModeling(TaskTemplate): - task: str = "language-modeling" + task: str = field(default="language-modeling", metadata={"include_in_asdict_even_if_is_default": True}) input_schema: ClassVar[Features] = Features({"text": Value("string")}) label_schema: ClassVar[Features] = Features({}) diff --git a/src/datasets/tasks/question_answering.py b/src/datasets/tasks/question_answering.py index 5cce48c7038..349fd541417 100644 --- a/src/datasets/tasks/question_answering.py +++ b/src/datasets/tasks/question_answering.py @@ -1,4 +1,4 @@ -from dataclasses import dataclass +from dataclasses import dataclass, field from typing import ClassVar, Dict from ..features import Features, Sequence, Value @@ -8,7 +8,7 @@ @dataclass(frozen=True) class QuestionAnsweringExtractive(TaskTemplate): # `task` is not a ClassVar since we want it to be part of the `asdict` output for JSON serialization - task: str = "question-answering-extractive" + task: str = field(default="question-answering-extractive", metadata={"include_in_asdict_even_if_is_default": True}) input_schema: ClassVar[Features] = Features({"question": Value("string"), "context": Value("string")}) label_schema: ClassVar[Features] = Features( { diff --git a/src/datasets/tasks/summarization.py b/src/datasets/tasks/summarization.py index 0e99b9d3b7d..a0057b07b4f 100644 --- a/src/datasets/tasks/summarization.py +++ b/src/datasets/tasks/summarization.py @@ -1,4 +1,4 @@ -from dataclasses import dataclass +from dataclasses import dataclass, field from typing import ClassVar, Dict from ..features import Features, Value @@ -8,7 +8,7 @@ @dataclass(frozen=True) class Summarization(TaskTemplate): # `task` is not a ClassVar since we want it to be part of the `asdict` output for JSON serialization - task: str = "summarization" + task: str = field(default="summarization", metadata={"include_in_asdict_even_if_is_default": True}) input_schema: ClassVar[Features] = Features({"text": Value("string")}) label_schema: ClassVar[Features] = Features({"summary": Value("string")}) text_column: str = "text" diff --git a/src/datasets/tasks/text_classification.py b/src/datasets/tasks/text_classification.py index 6d430022384..13584b73e8a 100644 --- a/src/datasets/tasks/text_classification.py +++ b/src/datasets/tasks/text_classification.py @@ -1,5 +1,5 @@ import copy -from dataclasses import dataclass +from dataclasses import dataclass, field from typing import ClassVar, Dict from ..features import ClassLabel, Features, Value @@ -9,7 +9,7 @@ @dataclass(frozen=True) class TextClassification(TaskTemplate): # `task` is not a ClassVar since we want it to be part of the `asdict` output for JSON serialization - task: str = "text-classification" + task: str = field(default="text-classification", metadata={"include_in_asdict_even_if_is_default": True}) input_schema: ClassVar[Features] = Features({"text": Value("string")}) label_schema: ClassVar[Features] = Features({"labels": ClassLabel}) text_column: str = "text" diff --git a/tests/test_tasks.py b/tests/test_tasks.py index a2877ce9185..210e71277b6 100644 --- a/tests/test_tasks.py +++ b/tests/test_tasks.py @@ -1,6 +1,8 @@ from copy import deepcopy from unittest.case import TestCase +import pytest + from datasets.arrow_dataset import Dataset from datasets.features import Audio, ClassLabel, Features, Image, Sequence, Value from datasets.info import DatasetInfo @@ -12,7 +14,9 @@ QuestionAnsweringExtractive, Summarization, TextClassification, + task_template_from_dict, ) +from datasets.utils.py_utils import asdict SAMPLE_QUESTION_ANSWERING_EXTRACTIVE = { @@ -24,6 +28,25 @@ } +@pytest.mark.parametrize( + "task_cls", + [ + AudioClassification, + AutomaticSpeechRecognition, + ImageClassification, + LanguageModeling, + QuestionAnsweringExtractive, + Summarization, + TextClassification, + ], +) +def test_reload_task_from_dict(task_cls): + task = task_cls() + task_dict = asdict(task) + reloaded = task_template_from_dict(task_dict) + assert task == reloaded + + class TestLanguageModeling: def test_column_mapping(self): task = LanguageModeling(text_column="input_text")