TF/Numpy variants for all DataCollator classes #13105

Merged 36 commits on Aug 31, 2021
6775d6d
Adding a TF variant of the DataCollatorForTokenClassification to get …
Rocketknight1 Aug 12, 2021
0ae2587
Added a Numpy variant and a post_init check to fail early if a missin…
Rocketknight1 Aug 12, 2021
d577ac9
Fixed call to Numpy variant
Rocketknight1 Aug 12, 2021
346a646
Added a couple more of the collators
Rocketknight1 Aug 12, 2021
33dbb29
Update src/transformers/data/data_collator.py
Rocketknight1 Aug 13, 2021
70168b5
Fixes, style pass, finished DataCollatorForSeqToSeq
Rocketknight1 Aug 13, 2021
050208b
Added all the LanguageModeling DataCollators, except SOP and Permutat…
Rocketknight1 Aug 18, 2021
ce8dcc3
Adding DataCollatorForPermutationLanguageModeling
Rocketknight1 Aug 19, 2021
dda8906
Style pass
Rocketknight1 Aug 19, 2021
f81068b
Add missing `__call__` for PLM
Rocketknight1 Aug 19, 2021
abf8325
Remove `post_init` checks for frameworks because the imports inside t…
Rocketknight1 Aug 19, 2021
52cf828
Remove unused imports
Rocketknight1 Aug 19, 2021
4b823f9
First attempt at some TF tests
Rocketknight1 Aug 19, 2021
6cf757a
A second attempt to make any of those tests actually work
Rocketknight1 Aug 19, 2021
6db05c6
TF tests, round three
Rocketknight1 Aug 19, 2021
212b6e0
TF tests, round four
Rocketknight1 Aug 20, 2021
dbb3848
TF tests, round five
Rocketknight1 Aug 20, 2021
2d59e3d
TF tests, all enabled!
Rocketknight1 Aug 23, 2021
eed3c88
Style pass
Rocketknight1 Aug 23, 2021
3616216
Merging tests into `test_data_collator.py`
Rocketknight1 Aug 23, 2021
f413717
Merging tests into `test_data_collator.py`
Rocketknight1 Aug 23, 2021
71cb01c
Fixing up test imports
Rocketknight1 Aug 23, 2021
984496d
Fixing up test imports
Rocketknight1 Aug 23, 2021
d9b16f7
Trying shuffling the conditionals around
Rocketknight1 Aug 24, 2021
e1975bd
Commenting out non-functional old tests
Rocketknight1 Aug 24, 2021
1616c6f
Completed all tests for all three frameworks
Rocketknight1 Aug 25, 2021
e57a75b
Style pass
Rocketknight1 Aug 25, 2021
9b654d9
Fixed test typo
Rocketknight1 Aug 25, 2021
95b0ddc
Style pass
Rocketknight1 Aug 25, 2021
472f74b
Move standard `__call__` method to mixin
Rocketknight1 Aug 30, 2021
5791028
Rearranged imports for `test_data_collator`
Rocketknight1 Aug 30, 2021
582a95c
Fix data collator typo "torch" -> "pt"
Rocketknight1 Aug 30, 2021
4b9cfb5
Fixed the most embarrassingly obvious bug
Rocketknight1 Aug 30, 2021
dc58c69
Update src/transformers/data/data_collator.py
Rocketknight1 Aug 30, 2021
375c316
Renaming mixin
Rocketknight1 Aug 30, 2021
28606ec
Updating docs
Rocketknight1 Aug 31, 2021
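
Two of the commits above ("Move standard `__call__` method to mixin" and "Fix data collator typo "torch" -> "pt"") suggest a pattern in which one shared `__call__` routes to a per-framework method. The following is a minimal sketch of that idea, not the merged code: `CollatorDispatchMixin`, `ToyPaddingCollator`, and the `tf_call` / `torch_call` / `numpy_call` method names are hypothetical, and the framework imports are deferred into the method bodies on the assumption that the module should stay importable when only one framework is installed.

```python
from typing import Any, Dict, List, Optional


class CollatorDispatchMixin:
    """Hypothetical mixin: one shared __call__ that routes by framework key."""

    return_tensors: str = "pt"  # assumed default key for this sketch

    def __call__(self, features: List[Dict[str, Any]], return_tensors: Optional[str] = None):
        # A per-call argument overrides the instance-level setting.
        framework = self.return_tensors if return_tensors is None else return_tensors
        if framework == "tf":
            return self.tf_call(features)
        if framework == "pt":
            return self.torch_call(features)
        if framework == "np":
            return self.numpy_call(features)
        # "torch" lands here: only the short keys are accepted in this sketch.
        raise ValueError(f"Framework '{framework}' not recognized")


class ToyPaddingCollator(CollatorDispatchMixin):
    """Illustrative collator that right-pads `input_ids` to the batch maximum."""

    def __init__(self, pad_id: int = 0, return_tensors: str = "np"):
        self.pad_id = pad_id
        self.return_tensors = return_tensors

    def pad(self, features: List[Dict[str, Any]]) -> List[List[int]]:
        # Framework-independent step shared by all three variants.
        longest = max(len(f["input_ids"]) for f in features)
        return [f["input_ids"] + [self.pad_id] * (longest - len(f["input_ids"])) for f in features]

    def numpy_call(self, features):
        import numpy as np  # imported lazily, only when this variant is used

        return {"input_ids": np.array(self.pad(features), dtype=np.int64)}

    def torch_call(self, features):
        import torch  # imported lazily, only when this variant is used

        return {"input_ids": torch.tensor(self.pad(features), dtype=torch.long)}

    def tf_call(self, features):
        import tensorflow as tf  # imported lazily, only when this variant is used

        return {"input_ids": tf.constant(self.pad(features), dtype=tf.int64)}
```

With NumPy installed, `ToyPaddingCollator()([{"input_ids": [1, 2, 3]}, {"input_ids": [4]}])` returns a dict whose `input_ids` array has shape `(2, 3)`. Passing `return_tensors="torch"` raises `ValueError` in this sketch, which is the kind of key mismatch the "torch" -> "pt" commit title hints at.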
6 changes: 3 additions & 3 deletions docs/source/main_classes/data_collator.rst
@@ -54,18 +54,18 @@ DataCollatorForLanguageModeling
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.data.data_collator.DataCollatorForLanguageModeling
- :members: mask_tokens
+ :members: numpy_mask_tokens, tf_mask_tokens, torch_mask_tokens


DataCollatorForWholeWordMask
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.data.data_collator.DataCollatorForWholeWordMask
- :members: mask_tokens
+ :members: numpy_mask_tokens, tf_mask_tokens, torch_mask_tokens


DataCollatorForPermutationLanguageModeling
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.data.data_collator.DataCollatorForPermutationLanguageModeling
- :members: mask_tokens
+ :members: numpy_mask_tokens, tf_mask_tokens, torch_mask_tokens
44 changes: 22 additions & 22 deletions src/transformers/__init__.py
@@ -81,6 +81,17 @@
"xnli_processors",
"xnli_tasks_num_labels",
],
+ "data.data_collator": [
+ "DataCollator",
+ "DataCollatorForLanguageModeling",
+ "DataCollatorForPermutationLanguageModeling",
+ "DataCollatorForSeq2Seq",
+ "DataCollatorForSOP",
+ "DataCollatorForTokenClassification",
+ "DataCollatorForWholeWordMask",
+ "DataCollatorWithPadding",
+ "default_data_collator",
+ ],
"feature_extraction_sequence_utils": ["BatchFeature", "SequenceFeatureExtractor"],
"file_utils": [
"CONFIG_NAME",
@@ -460,17 +471,6 @@
if is_torch_available():
_import_structure["benchmark.benchmark"] = ["PyTorchBenchmark"]
_import_structure["benchmark.benchmark_args"] = ["PyTorchBenchmarkArguments"]
- _import_structure["data.data_collator"] = [
- "DataCollator",
- "DataCollatorForLanguageModeling",
- "DataCollatorForPermutationLanguageModeling",
- "DataCollatorForSeq2Seq",
- "DataCollatorForSOP",
- "DataCollatorForTokenClassification",
- "DataCollatorForWholeWordMask",
- "DataCollatorWithPadding",
- "default_data_collator",
- ]
_import_structure["data.datasets"] = [
"GlueDataset",
"GlueDataTrainingArguments",
@@ -1818,6 +1818,17 @@
xnli_processors,
xnli_tasks_num_labels,
)
+ from .data.data_collator import (
+ DataCollator,
+ DataCollatorForLanguageModeling,
+ DataCollatorForPermutationLanguageModeling,
+ DataCollatorForSeq2Seq,
+ DataCollatorForSOP,
+ DataCollatorForTokenClassification,
+ DataCollatorForWholeWordMask,
+ DataCollatorWithPadding,
+ default_data_collator,
+ )

# Feature Extractor
from .feature_extraction_utils import BatchFeature, SequenceFeatureExtractor
@@ -2162,17 +2173,6 @@
# Benchmarks
from .benchmark.benchmark import PyTorchBenchmark
from .benchmark.benchmark_args import PyTorchBenchmarkArguments
- from .data.data_collator import (
- DataCollator,
- DataCollatorForLanguageModeling,
- DataCollatorForPermutationLanguageModeling,
- DataCollatorForSeq2Seq,
- DataCollatorForSOP,
- DataCollatorForTokenClassification,
- DataCollatorForWholeWordMask,
- DataCollatorWithPadding,
- default_data_collator,
- )
from .data.datasets import (
GlueDataset,
GlueDataTrainingArguments,