Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Deberta tf #12972

Merged
merged 11 commits into from Aug 12, 2021
2 changes: 1 addition & 1 deletion docs/source/index.rst
Expand Up @@ -343,7 +343,7 @@ Flax), PyTorch, and/or TensorFlow.
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| DPR | ✅ | ✅ | ✅ | ✅ | ❌ |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| DeBERTa | ✅ | ✅ | ✅ | | ❌ |
| DeBERTa | ✅ | ✅ | ✅ | | ❌ |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| DeBERTa-v2 | ✅ | ❌ | ✅ | ❌ | ❌ |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
Expand Down
45 changes: 44 additions & 1 deletion docs/source/model_doc/deberta.rst
Expand Up @@ -38,7 +38,8 @@ the training data performs consistently better on a wide range of NLP tasks, ach
pre-trained models will be made publicly available at https://github.com/microsoft/DeBERTa.*


This model was contributed by `DeBERTa <https://huggingface.co/DeBERTa>`__. The original code can be found `here
This model was contributed by `DeBERTa <https://huggingface.co/DeBERTa>`__. This model TF 2.0 implementation was
contributed by `kamalkraj <https://huggingface.co/kamalkraj>`__ . The original code can be found `here
<https://github.com/microsoft/DeBERTa>`__.


Expand Down Expand Up @@ -103,3 +104,45 @@ DebertaForQuestionAnswering

.. autoclass:: transformers.DebertaForQuestionAnswering
:members: forward


TFDebertaModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.TFDebertaModel
:members: call


TFDebertaPreTrainedModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.TFDebertaPreTrainedModel
:members: call


TFDebertaForMaskedLM
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.TFDebertaForMaskedLM
:members: call


TFDebertaForSequenceClassification
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.TFDebertaForSequenceClassification
:members: call


TFDebertaForTokenClassification
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.TFDebertaForTokenClassification
:members: call


TFDebertaForQuestionAnswering
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.TFDebertaForQuestionAnswering
:members: call
20 changes: 20 additions & 0 deletions src/transformers/__init__.py
Expand Up @@ -1297,6 +1297,17 @@
"TFCTRLPreTrainedModel",
]
)
_import_structure["models.deberta"].extend(
[
"TF_DEBERTA_PRETRAINED_MODEL_ARCHIVE_LIST",
"TFDebertaForMaskedLM",
"TFDebertaForQuestionAnswering",
"TFDebertaForSequenceClassification",
"TFDebertaForTokenClassification",
"TFDebertaModel",
"TFDebertaPreTrainedModel",
]
)
_import_structure["models.distilbert"].extend(
[
"TF_DISTILBERT_PRETRAINED_MODEL_ARCHIVE_LIST",
Expand Down Expand Up @@ -2820,6 +2831,15 @@
TFCTRLModel,
TFCTRLPreTrainedModel,
)
from .models.deberta import (
TF_DEBERTA_PRETRAINED_MODEL_ARCHIVE_LIST,
TFDebertaForMaskedLM,
TFDebertaForQuestionAnswering,
TFDebertaForSequenceClassification,
TFDebertaForTokenClassification,
TFDebertaModel,
TFDebertaPreTrainedModel,
)
from .models.distilbert import (
TF_DISTILBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
TFDistilBertForMaskedLM,
Expand Down
5 changes: 5 additions & 0 deletions src/transformers/models/auto/modeling_tf_auto.py
Expand Up @@ -29,6 +29,7 @@
TF_MODEL_MAPPING_NAMES = OrderedDict(
[
# Base model mapping
("deberta", "TFDebertaModel"),
("rembert", "TFRemBertModel"),
("roformer", "TFRoFormerModel"),
("convbert", "TFConvBertModel"),
Expand Down Expand Up @@ -144,6 +145,7 @@
TF_MODEL_FOR_MASKED_LM_MAPPING_NAMES = OrderedDict(
[
# Model for Masked LM mapping
("deberta", "TFDebertaForMaskedLM"),
("rembert", "TFRemBertForMaskedLM"),
("roformer", "TFRoFormerForMaskedLM"),
("convbert", "TFConvBertForMaskedLM"),
Expand Down Expand Up @@ -183,6 +185,7 @@
TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
[
# Model for Sequence Classification mapping
("deberta", "TFDebertaForSequenceClassification"),
("rembert", "TFRemBertForSequenceClassification"),
("roformer", "TFRoFormerForSequenceClassification"),
("convbert", "TFConvBertForSequenceClassification"),
Expand Down Expand Up @@ -211,6 +214,7 @@
TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
[
# Model for Question Answering mapping
("deberta", "TFDebertaForQuestionAnswering"),
("rembert", "TFRemBertForQuestionAnswering"),
("roformer", "TFRoFormerForQuestionAnswering"),
("convbert", "TFConvBertForQuestionAnswering"),
Expand All @@ -234,6 +238,7 @@
TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
[
# Model for Token Classification mapping
("deberta", "TFDebertaForTokenClassification"),
("rembert", "TFRemBertForTokenClassification"),
("roformer", "TFRoFormerForTokenClassification"),
("convbert", "TFConvBertForTokenClassification"),
Expand Down
25 changes: 24 additions & 1 deletion src/transformers/models/deberta/__init__.py
Expand Up @@ -18,7 +18,7 @@

from typing import TYPE_CHECKING

from ...file_utils import _LazyModule, is_tokenizers_available, is_torch_available
from ...file_utils import _LazyModule, is_tf_available, is_tokenizers_available, is_torch_available


_import_structure = {
Expand All @@ -40,6 +40,17 @@
"DebertaPreTrainedModel",
]

if is_tf_available():
_import_structure["modeling_tf_deberta"] = [
"TF_DEBERTA_PRETRAINED_MODEL_ARCHIVE_LIST",
"TFDebertaForMaskedLM",
"TFDebertaForQuestionAnswering",
"TFDebertaForSequenceClassification",
"TFDebertaForTokenClassification",
"TFDebertaModel",
"TFDebertaPreTrainedModel",
]


if TYPE_CHECKING:
from .configuration_deberta import DEBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, DebertaConfig
Expand All @@ -59,6 +70,18 @@
DebertaPreTrainedModel,
)

if is_tf_available():
from .modeling_tf_deberta import (
TF_DEBERTA_PRETRAINED_MODEL_ARCHIVE_LIST,
TFDebertaForMaskedLM,
TFDebertaForQuestionAnswering,
TFDebertaForSequenceClassification,
TFDebertaForTokenClassification,
TFDebertaModel,
TFDebertaPreTrainedModel,
)


else:
import sys

Expand Down