Add Data2Vec for Vision in TF (huggingface#17008)
* add utilities till TFData2VecVisionLayer.

* chore: pass window_size to attention layer.

* feat: add TFData2VecVisionRelativePositionBias.

* feat: initial implementation ready for tf data2vec.

* fix: relative position bias index, table to be fixed.

* chore: implementation added, tests remaining.

* add: tests, other PR files.

* fix: code quality.

* fix: import structure in init.

* chore: run make fix-copies.

* chore: address PR feedback (round I).

* chore: styling nit.

* fix: tests due to removal of to_2tuple().

* chore: rebase with upstream main and move the test.

* Update src/transformers/models/auto/modeling_tf_auto.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/models/auto/modeling_tf_auto.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* fix: layer call.

* chore: remove from_pt=True and rerun test.

* chore: remove cast and tf.divide.

* chore: minor edits to the test script.

* Update src/transformers/models/data2vec/modeling_tf_data2vec_vision.py

Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>

* fix: expand() on TF tensors with broadcast_to().

* fix: test import.

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>
3 people authored and nandwalritik committed May 4, 2022
1 parent b13cc89 commit 2eff2e4
Showing 8 changed files with 1,511 additions and 3 deletions.
2 changes: 1 addition & 1 deletion docs/source/en/index.mdx
@@ -191,7 +191,7 @@ Flax), PyTorch, and/or TensorFlow.
| CTRL | | | | | |
| Data2VecAudio | | | | | |
| Data2VecText | | | | | |
| Data2VecVision | | | | | |
| Data2VecVision | | | | | |
| DeBERTa | | | | | |
| DeBERTa-v2 | | | | | |
| Decision Transformer | | | | | |
16 changes: 14 additions & 2 deletions docs/source/en/model_doc/data2vec.mdx
@@ -38,9 +38,11 @@ Tips:
- For Data2VecText, preprocessing is identical to [`RobertaModel`], including tokenization.
- For Data2VecVision, preprocessing is identical to [`BeitModel`], including feature extraction.

This model was contributed by [edugp](https://huggingface.co/edugp) and [patrickvonplaten](https://huggingface.co/patrickvonplaten)
This model was contributed by [edugp](https://huggingface.co/edugp) and [patrickvonplaten](https://huggingface.co/patrickvonplaten).
[sayakpaul](https://github.com/sayakpaul) contributed Data2Vec for vision in TensorFlow.

The original code can be found [here](https://github.com/pytorch/fairseq/tree/main/examples/data2vec).
The original code (for NLP and Speech) can be found [here](https://github.com/pytorch/fairseq/tree/main/examples/data2vec).
The original code for vision can be found [here](https://github.com/facebookresearch/data2vec_vision/tree/main/beit).


## Data2VecTextConfig
@@ -130,3 +132,13 @@ The original code can be found [here](https://github.com/pytorch/fairseq/tree/ma

[[autodoc]] Data2VecVisionForSemanticSegmentation
- forward

## TFData2VecVisionModel

[[autodoc]] TFData2VecVisionModel
- call

## TFData2VecVisionForImageClassification

[[autodoc]] TFData2VecVisionForImageClassification
- call
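
A minimal usage sketch for the TensorFlow classes documented above, assuming `transformers` and `tensorflow` are installed. The helper name `classify_image` and the checkpoint name `facebook/data2vec-vision-base-ft1k` are assumptions for illustration and do not appear in this diff. `BeitFeatureExtractor` is used because the tips above state that Data2VecVision preprocessing is identical to `BeitModel`.

```python
def classify_image(image, checkpoint="facebook/data2vec-vision-base-ft1k"):
    """Return the predicted label for a PIL image (hypothetical helper)."""
    # Imports are deferred so that defining the helper does not require
    # TensorFlow or transformers to be installed.
    import tensorflow as tf
    from transformers import BeitFeatureExtractor, TFData2VecVisionForImageClassification

    # Assumption: the checkpoint ships both preprocessing config and TF weights.
    feature_extractor = BeitFeatureExtractor.from_pretrained(checkpoint)
    model = TFData2VecVisionForImageClassification.from_pretrained(checkpoint)

    # Resize/normalize the image and return TensorFlow tensors.
    inputs = feature_extractor(images=image, return_tensors="tf")
    logits = model(**inputs).logits

    # Pick the highest-scoring class for the single image in the batch.
    predicted_id = int(tf.math.argmax(logits, axis=-1)[0])
    return model.config.id2label[predicted_id]
```

Calling `classify_image(PIL.Image.open(path))` downloads the checkpoint on first use, so it needs network access.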
12 changes: 12 additions & 0 deletions src/transformers/__init__.py
@@ -1878,6 +1878,13 @@
"TFCTRLPreTrainedModel",
]
)
_import_structure["models.data2vec"].extend(
    [
        "TFData2VecVisionForImageClassification",
        "TFData2VecVisionModel",
        "TFData2VecVisionPreTrainedModel",
    ]
)
_import_structure["models.deberta"].extend(
[
"TF_DEBERTA_PRETRAINED_MODEL_ARCHIVE_LIST",
@@ -4029,6 +4036,11 @@
TFCTRLModel,
TFCTRLPreTrainedModel,
)
from .models.data2vec import (
    TFData2VecVisionForImageClassification,
    TFData2VecVisionModel,
    TFData2VecVisionPreTrainedModel,
)
from .models.deberta import (
TF_DEBERTA_PRETRAINED_MODEL_ARCHIVE_LIST,
TFDebertaForMaskedLM,
2 changes: 2 additions & 0 deletions src/transformers/models/auto/modeling_tf_auto.py
@@ -37,6 +37,7 @@
("roformer", "TFRoFormerModel"),
("convbert", "TFConvBertModel"),
("convnext", "TFConvNextModel"),
("data2vec-vision", "TFData2VecVisionModel"),
("led", "TFLEDModel"),
("lxmert", "TFLxmertModel"),
("mt5", "TFMT5Model"),
@@ -163,6 +164,7 @@
# Model for Image-classsification
("vit", "TFViTForImageClassification"),
("convnext", "TFConvNextForImageClassification"),
("data2vec-vision", "TFData2VecVisionForImageClassification"),
]
)
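
The two entries added in this file register the `data2vec-vision` model type with the TF auto classes. The lookup they enable can be sketched as below. This is a simplified stand-in, not the library's code: the mapping name `IMAGE_CLASSIFICATION_NAMES` and the helper `resolve_class_name` are hypothetical, and the mapping is trimmed to the three entries visible in the hunk.

```python
from collections import OrderedDict

# Hypothetical, trimmed-down copy of the image-classification mapping
# extended in this diff: model type -> TF class name.
IMAGE_CLASSIFICATION_NAMES = OrderedDict(
    [
        ("vit", "TFViTForImageClassification"),
        ("convnext", "TFConvNextForImageClassification"),
        ("data2vec-vision", "TFData2VecVisionForImageClassification"),
    ]
)


def resolve_class_name(model_type, mapping=IMAGE_CLASSIFICATION_NAMES):
    """Return the class name registered for `model_type` (illustrative helper)."""
    try:
        return mapping[model_type]
    except KeyError:
        known = ", ".join(mapping)
        raise ValueError(
            f"Unrecognized model type {model_type!r}; known types: {known}"
        ) from None


print(resolve_class_name("data2vec-vision"))
```

Without the new entry, a `data2vec-vision` config would take the `ValueError` branch of this sketch.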

15 changes: 15 additions & 0 deletions src/transformers/models/data2vec/__init__.py
@@ -18,6 +18,8 @@

from typing import TYPE_CHECKING

from transformers.utils.import_utils import is_tf_available

from ...utils import _LazyModule, is_torch_available


@@ -68,6 +70,13 @@
"Data2VecVisionPreTrainedModel",
]

if is_tf_available():
    _import_structure["modeling_tf_data2vec_vision"] = [
        "TFData2VecVisionForImageClassification",
        "TFData2VecVisionModel",
        "TFData2VecVisionPreTrainedModel",
    ]

if TYPE_CHECKING:
from .configuration_data2vec_audio import DATA2VEC_AUDIO_PRETRAINED_CONFIG_ARCHIVE_MAP, Data2VecAudioConfig
from .configuration_data2vec_text import (
@@ -110,6 +119,12 @@
Data2VecVisionModel,
Data2VecVisionPreTrainedModel,
)
    if is_tf_available():
        from .modeling_tf_data2vec_vision import (
            TFData2VecVisionForImageClassification,
            TFData2VecVisionModel,
            TFData2VecVisionPreTrainedModel,
        )

else:
import sys
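
The `_import_structure` dictionary in this file maps each submodule to the names it exports, so that the heavy modeling code is imported only when one of those names is first accessed. The idea can be sketched as follows. `LazyModule` is a simplified, hypothetical stand-in for the library's `_LazyModule`, not its actual implementation, and standard-library modules stand in for the real submodules.

```python
import importlib
import types


class LazyModule(types.ModuleType):
    """Simplified stand-in for a lazy module (illustrative, not library code)."""

    def __init__(self, name, import_structure):
        super().__init__(name)
        # Invert {submodule: [names]} into {name: submodule} for direct lookup.
        self._name_to_module = {
            attr: module
            for module, attrs in import_structure.items()
            for attr in attrs
        }

    def __getattr__(self, attr):
        # Python calls this only when normal attribute lookup has failed,
        # which here means the name has not been imported yet.
        if attr not in self._name_to_module:
            raise AttributeError(f"module {self.__name__!r} has no attribute {attr!r}")
        module = importlib.import_module(self._name_to_module[attr])
        value = getattr(module, attr)
        setattr(self, attr, value)  # cache so later lookups skip __getattr__
        return value


# Standard-library modules play the role of the real submodules here.
lazy = LazyModule("demo", {"json": ["dumps"], "math": ["sqrt"]})
```

Accessing `lazy.sqrt` imports `math` on first use and caches the function on the module object. In this sketch, guarding an entry with a check such as `is_tf_available()` simply keeps the name out of the structure when the backend is missing.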
