Skip to content

Commit

Permalink
fix task template reload from dict
Browse files Browse the repository at this point in the history
  • Loading branch information
lhoestq committed Oct 12, 2022
1 parent dc4c764 commit 884bca7
Show file tree
Hide file tree
Showing 8 changed files with 37 additions and 14 deletions.
4 changes: 2 additions & 2 deletions src/datasets/tasks/audio_classificiation.py
@@ -1,5 +1,5 @@
import copy
from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import ClassVar, Dict

from ..features import Audio, ClassLabel, Features
Expand All @@ -8,7 +8,7 @@

@dataclass(frozen=True)
class AudioClassification(TaskTemplate):
task: str = "audio-classification"
task: str = field(default="audio-classification", metadata={"include_in_asdict_even_if_is_default": True})
input_schema: ClassVar[Features] = Features({"audio": Audio()})
label_schema: ClassVar[Features] = Features({"labels": ClassLabel})
audio_column: str = "audio"
Expand Down
4 changes: 2 additions & 2 deletions src/datasets/tasks/automatic_speech_recognition.py
@@ -1,5 +1,5 @@
import copy
from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import ClassVar, Dict

from ..features import Audio, Features, Value
Expand All @@ -8,7 +8,7 @@

@dataclass(frozen=True)
class AutomaticSpeechRecognition(TaskTemplate):
task: str = "automatic-speech-recognition"
task: str = field(default="automatic-speech-recognition", metadata={"include_in_asdict_even_if_is_default": True})
input_schema: ClassVar[Features] = Features({"audio": Audio()})
label_schema: ClassVar[Features] = Features({"transcription": Value("string")})
audio_column: str = "audio"
Expand Down
4 changes: 2 additions & 2 deletions src/datasets/tasks/image_classification.py
@@ -1,5 +1,5 @@
import copy
from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import ClassVar, Dict

from ..features import ClassLabel, Features, Image
Expand All @@ -8,7 +8,7 @@

@dataclass(frozen=True)
class ImageClassification(TaskTemplate):
task: str = "image-classification"
task: str = field(default="image-classification", metadata={"include_in_asdict_even_if_is_default": True})
input_schema: ClassVar[Features] = Features({"image": Image()})
label_schema: ClassVar[Features] = Features({"labels": ClassLabel})
image_column: str = "image"
Expand Down
4 changes: 2 additions & 2 deletions src/datasets/tasks/language_modeling.py
@@ -1,4 +1,4 @@
from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import ClassVar, Dict

from ..features import Features, Value
Expand All @@ -7,7 +7,7 @@

@dataclass(frozen=True)
class LanguageModeling(TaskTemplate):
task: str = "language-modeling"
task: str = field(default="language-modeling", metadata={"include_in_asdict_even_if_is_default": True})

input_schema: ClassVar[Features] = Features({"text": Value("string")})
label_schema: ClassVar[Features] = Features({})
Expand Down
4 changes: 2 additions & 2 deletions src/datasets/tasks/question_answering.py
@@ -1,4 +1,4 @@
from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import ClassVar, Dict

from ..features import Features, Sequence, Value
Expand All @@ -8,7 +8,7 @@
@dataclass(frozen=True)
class QuestionAnsweringExtractive(TaskTemplate):
# `task` is not a ClassVar since we want it to be part of the `asdict` output for JSON serialization
task: str = "question-answering-extractive"
task: str = field(default="question-answering-extractive", metadata={"include_in_asdict_even_if_is_default": True})
input_schema: ClassVar[Features] = Features({"question": Value("string"), "context": Value("string")})
label_schema: ClassVar[Features] = Features(
{
Expand Down
4 changes: 2 additions & 2 deletions src/datasets/tasks/summarization.py
@@ -1,4 +1,4 @@
from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import ClassVar, Dict

from ..features import Features, Value
Expand All @@ -8,7 +8,7 @@
@dataclass(frozen=True)
class Summarization(TaskTemplate):
# `task` is not a ClassVar since we want it to be part of the `asdict` output for JSON serialization
task: str = "summarization"
task: str = field(default="summarization", metadata={"include_in_asdict_even_if_is_default": True})
input_schema: ClassVar[Features] = Features({"text": Value("string")})
label_schema: ClassVar[Features] = Features({"summary": Value("string")})
text_column: str = "text"
Expand Down
4 changes: 2 additions & 2 deletions src/datasets/tasks/text_classification.py
@@ -1,5 +1,5 @@
import copy
from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import ClassVar, Dict

from ..features import ClassLabel, Features, Value
Expand All @@ -9,7 +9,7 @@
@dataclass(frozen=True)
class TextClassification(TaskTemplate):
# `task` is not a ClassVar since we want it to be part of the `asdict` output for JSON serialization
task: str = "text-classification"
task: str = field(default="text-classification", metadata={"include_in_asdict_even_if_is_default": True})
input_schema: ClassVar[Features] = Features({"text": Value("string")})
label_schema: ClassVar[Features] = Features({"labels": ClassLabel})
text_column: str = "text"
Expand Down
23 changes: 23 additions & 0 deletions tests/test_tasks.py
@@ -1,6 +1,8 @@
from copy import deepcopy
from unittest.case import TestCase

import pytest

from datasets.arrow_dataset import Dataset
from datasets.features import Audio, ClassLabel, Features, Image, Sequence, Value
from datasets.info import DatasetInfo
Expand All @@ -12,7 +14,9 @@
QuestionAnsweringExtractive,
Summarization,
TextClassification,
task_template_from_dict,
)
from datasets.utils.py_utils import asdict


SAMPLE_QUESTION_ANSWERING_EXTRACTIVE = {
Expand All @@ -24,6 +28,25 @@
}


@pytest.mark.parametrize(
"task_cls",
[
AudioClassification,
AutomaticSpeechRecognition,
ImageClassification,
LanguageModeling,
QuestionAnsweringExtractive,
Summarization,
TextClassification,
],
)
def test_reload_task_from_dict(task_cls):
task = task_cls()
task_dict = asdict(task)
reloaded = task_template_from_dict(task_dict)
assert task == reloaded


class TestLanguageModeling:
def test_column_mapping(self):
task = LanguageModeling(text_column="input_text")
Expand Down

1 comment on commit 884bca7

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Show benchmarks

PyArrow==6.0.0

Show updated benchmarks!

Benchmark: benchmark_array_xd.json

metric read_batch_formatted_as_numpy after write_array2d read_batch_formatted_as_numpy after write_flattened_sequence read_batch_formatted_as_numpy after write_nested_sequence read_batch_unformated after write_array2d read_batch_unformated after write_flattened_sequence read_batch_unformated after write_nested_sequence read_col_formatted_as_numpy after write_array2d read_col_formatted_as_numpy after write_flattened_sequence read_col_formatted_as_numpy after write_nested_sequence read_col_unformated after write_array2d read_col_unformated after write_flattened_sequence read_col_unformated after write_nested_sequence read_formatted_as_numpy after write_array2d read_formatted_as_numpy after write_flattened_sequence read_formatted_as_numpy after write_nested_sequence read_unformated after write_array2d read_unformated after write_flattened_sequence read_unformated after write_nested_sequence write_array2d write_flattened_sequence write_nested_sequence
new / old (diff) 0.008122 / 0.011353 (-0.003231) 0.004458 / 0.011008 (-0.006550) 0.098941 / 0.038508 (0.060433) 0.028972 / 0.023109 (0.005863) 0.303907 / 0.275898 (0.028009) 0.358410 / 0.323480 (0.034930) 0.006608 / 0.007986 (-0.001377) 0.004194 / 0.004328 (-0.000134) 0.074202 / 0.004250 (0.069952) 0.033914 / 0.037052 (-0.003138) 0.314346 / 0.258489 (0.055857) 0.360516 / 0.293841 (0.066675) 0.038100 / 0.128546 (-0.090446) 0.014245 / 0.075646 (-0.061401) 0.322650 / 0.419271 (-0.096621) 0.044281 / 0.043533 (0.000748) 0.311873 / 0.255139 (0.056734) 0.333007 / 0.283200 (0.049808) 0.083449 / 0.141683 (-0.058234) 1.527138 / 1.452155 (0.074984) 1.528919 / 1.492716 (0.036203)

Benchmark: benchmark_getitem_100B.json

metric get_batch_of_1024_random_rows get_batch_of_1024_rows get_first_row get_last_row
new / old (diff) 0.012930 / 0.018006 (-0.005076) 0.410717 / 0.000490 (0.410228) 0.004409 / 0.000200 (0.004209) 0.000077 / 0.000054 (0.000023)

Benchmark: benchmark_indices_mapping.json

metric select shard shuffle sort train_test_split
new / old (diff) 0.021306 / 0.037411 (-0.016105) 0.091163 / 0.014526 (0.076637) 0.104332 / 0.176557 (-0.072225) 0.145030 / 0.737135 (-0.592105) 0.105091 / 0.296338 (-0.191248)

Benchmark: benchmark_iterating.json

metric read 5000 read 50000 read_batch 50000 10 read_batch 50000 100 read_batch 50000 1000 read_formatted numpy 5000 read_formatted pandas 5000 read_formatted tensorflow 5000 read_formatted torch 5000 read_formatted_batch numpy 5000 10 read_formatted_batch numpy 5000 1000 shuffled read 5000 shuffled read 50000 shuffled read_batch 50000 10 shuffled read_batch 50000 100 shuffled read_batch 50000 1000 shuffled read_formatted numpy 5000 shuffled read_formatted_batch numpy 5000 10 shuffled read_formatted_batch numpy 5000 1000
new / old (diff) 0.408809 / 0.215209 (0.193600) 4.085681 / 2.077655 (2.008027) 1.867399 / 1.504120 (0.363279) 1.666169 / 1.541195 (0.124975) 1.671474 / 1.468490 (0.202984) 0.681137 / 4.584777 (-3.903640) 3.320371 / 3.745712 (-0.425341) 1.809087 / 5.269862 (-3.460774) 1.135986 / 4.565676 (-3.429690) 0.081002 / 0.424275 (-0.343273) 0.011350 / 0.007607 (0.003743) 0.512846 / 0.226044 (0.286802) 5.146968 / 2.268929 (2.878040) 2.250372 / 55.444624 (-53.194253) 1.923299 / 6.876477 (-4.953177) 1.991251 / 2.142072 (-0.150821) 0.802524 / 4.805227 (-4.002703) 0.147459 / 6.500664 (-6.353205) 0.063121 / 0.075469 (-0.012349)

Benchmark: benchmark_map_filter.json

metric filter map fast-tokenizer batched map identity map identity batched map no-op batched map no-op batched numpy map no-op batched pandas map no-op batched pytorch map no-op batched tensorflow
new / old (diff) 1.523131 / 1.841788 (-0.318657) 12.162836 / 8.074308 (4.088528) 25.662715 / 10.191392 (15.471323) 0.874663 / 0.680424 (0.194240) 0.621430 / 0.534201 (0.087229) 0.384049 / 0.579283 (-0.195234) 0.394382 / 0.434364 (-0.039982) 0.242860 / 0.540337 (-0.297477) 0.249149 / 1.386936 (-1.137787)
PyArrow==latest
Show updated benchmarks!

Benchmark: benchmark_array_xd.json

metric read_batch_formatted_as_numpy after write_array2d read_batch_formatted_as_numpy after write_flattened_sequence read_batch_formatted_as_numpy after write_nested_sequence read_batch_unformated after write_array2d read_batch_unformated after write_flattened_sequence read_batch_unformated after write_nested_sequence read_col_formatted_as_numpy after write_array2d read_col_formatted_as_numpy after write_flattened_sequence read_col_formatted_as_numpy after write_nested_sequence read_col_unformated after write_array2d read_col_unformated after write_flattened_sequence read_col_unformated after write_nested_sequence read_formatted_as_numpy after write_array2d read_formatted_as_numpy after write_flattened_sequence read_formatted_as_numpy after write_nested_sequence read_unformated after write_array2d read_unformated after write_flattened_sequence read_unformated after write_nested_sequence write_array2d write_flattened_sequence write_nested_sequence
new / old (diff) 0.006396 / 0.011353 (-0.004956) 0.004522 / 0.011008 (-0.006487) 0.097461 / 0.038508 (0.058953) 0.027091 / 0.023109 (0.003982) 0.413838 / 0.275898 (0.137940) 0.442686 / 0.323480 (0.119206) 0.004883 / 0.007986 (-0.003103) 0.003366 / 0.004328 (-0.000962) 0.074644 / 0.004250 (0.070394) 0.032819 / 0.037052 (-0.004234) 0.419074 / 0.258489 (0.160585) 0.452295 / 0.293841 (0.158454) 0.033874 / 0.128546 (-0.094673) 0.011821 / 0.075646 (-0.063826) 0.323866 / 0.419271 (-0.095405) 0.042031 / 0.043533 (-0.001502) 0.413538 / 0.255139 (0.158399) 0.446651 / 0.283200 (0.163451) 0.085831 / 0.141683 (-0.055851) 1.606069 / 1.452155 (0.153914) 1.582235 / 1.492716 (0.089519)

Benchmark: benchmark_getitem_100B.json

metric get_batch_of_1024_random_rows get_batch_of_1024_rows get_first_row get_last_row
new / old (diff) 0.218104 / 0.018006 (0.200098) 0.397114 / 0.000490 (0.396625) 0.001011 / 0.000200 (0.000811) 0.000067 / 0.000054 (0.000013)

Benchmark: benchmark_indices_mapping.json

metric select shard shuffle sort train_test_split
new / old (diff) 0.020219 / 0.037411 (-0.017192) 0.090841 / 0.014526 (0.076316) 0.104556 / 0.176557 (-0.072001) 0.146130 / 0.737135 (-0.591006) 0.104374 / 0.296338 (-0.191965)

Benchmark: benchmark_iterating.json

metric read 5000 read 50000 read_batch 50000 10 read_batch 50000 100 read_batch 50000 1000 read_formatted numpy 5000 read_formatted pandas 5000 read_formatted tensorflow 5000 read_formatted torch 5000 read_formatted_batch numpy 5000 10 read_formatted_batch numpy 5000 1000 shuffled read 5000 shuffled read 50000 shuffled read_batch 50000 10 shuffled read_batch 50000 100 shuffled read_batch 50000 1000 shuffled read_formatted numpy 5000 shuffled read_formatted_batch numpy 5000 10 shuffled read_formatted_batch numpy 5000 1000
new / old (diff) 0.473003 / 0.215209 (0.257794) 4.723688 / 2.077655 (2.646034) 2.431309 / 1.504120 (0.927189) 2.241957 / 1.541195 (0.700762) 2.249288 / 1.468490 (0.780798) 0.697116 / 4.584777 (-3.887661) 3.330119 / 3.745712 (-0.415593) 1.810640 / 5.269862 (-3.459221) 1.143017 / 4.565676 (-3.422660) 0.081466 / 0.424275 (-0.342809) 0.011648 / 0.007607 (0.004041) 0.577748 / 0.226044 (0.351703) 5.788931 / 2.268929 (3.520002) 2.834682 / 55.444624 (-52.609942) 2.506364 / 6.876477 (-4.370113) 2.549693 / 2.142072 (0.407620) 0.811939 / 4.805227 (-3.993289) 0.148741 / 6.500664 (-6.351923) 0.064590 / 0.075469 (-0.010879)

Benchmark: benchmark_map_filter.json

metric filter map fast-tokenizer batched map identity map identity batched map no-op batched map no-op batched numpy map no-op batched pandas map no-op batched pytorch map no-op batched tensorflow
new / old (diff) 1.576874 / 1.841788 (-0.264914) 12.308948 / 8.074308 (4.234640) 12.494239 / 10.191392 (2.302847) 0.906729 / 0.680424 (0.226305) 0.642011 / 0.534201 (0.107810) 0.371238 / 0.579283 (-0.208045) 0.377041 / 0.434364 (-0.057323) 0.214945 / 0.540337 (-0.325392) 0.218113 / 1.386936 (-1.168823)

CML watermark

Please sign in to comment.