Skip to content

Commit

Permalink
Deberta tf (#12972)
Browse files Browse the repository at this point in the history
* TFDeberta

moved weights to build and fixed name scope

added missing ,

bug fixes to enable graph mode execution

updated setup.py

fixing typo

fix imports

embedding mask fix

added layer names avoid autmatic incremental names

+XSoftmax

cleanup

added names to layer

disable keras_serializable
Distangled attention output shape hidden_size==None
using symbolic inputs

test for Deberta tf

make style

Update src/transformers/models/deberta/modeling_tf_deberta.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

Update src/transformers/models/deberta/modeling_tf_deberta.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

Update src/transformers/models/deberta/modeling_tf_deberta.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

Update src/transformers/models/deberta/modeling_tf_deberta.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

Update src/transformers/models/deberta/modeling_tf_deberta.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

Update src/transformers/models/deberta/modeling_tf_deberta.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

Update src/transformers/models/deberta/modeling_tf_deberta.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

removed tensorflow-probability

removed blank line

* removed tf experimental api
+torch_gather tf implementation from @Rocketknight1

* layername DeBERTa --> deberta

* copyright fix

* added docs for TFDeberta & make style

* layer_name change to fix load from pt model

* layer_name change as pt model

* SequenceClassification layername change,
to same as pt model

* switched to keras built-in LayerNormalization

* added `TFDeberta` prefix most layer classes

* updated to tf.Tensor in the docstring
  • Loading branch information
kamalkraj committed Aug 12, 2021
1 parent c4e1586 commit d329b63
Show file tree
Hide file tree
Showing 8 changed files with 1,988 additions and 3 deletions.
2 changes: 1 addition & 1 deletion docs/source/index.rst
Expand Up @@ -343,7 +343,7 @@ Flax), PyTorch, and/or TensorFlow.
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| DPR ||||||
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| DeBERTa |||| ||
| DeBERTa |||| ||
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| DeBERTa-v2 ||||||
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
Expand Down
45 changes: 44 additions & 1 deletion docs/source/model_doc/deberta.rst
Expand Up @@ -38,7 +38,8 @@ the training data performs consistently better on a wide range of NLP tasks, ach
pre-trained models will be made publicly available at https://github.com/microsoft/DeBERTa.*


This model was contributed by `DeBERTa <https://huggingface.co/DeBERTa>`__. The original code can be found `here
This model was contributed by `DeBERTa <https://huggingface.co/DeBERTa>`__. This model TF 2.0 implementation was
contributed by `kamalkraj <https://huggingface.co/kamalkraj>`__ . The original code can be found `here
<https://github.com/microsoft/DeBERTa>`__.


Expand Down Expand Up @@ -103,3 +104,45 @@ DebertaForQuestionAnswering

.. autoclass:: transformers.DebertaForQuestionAnswering
:members: forward


TFDebertaModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.TFDebertaModel
:members: call


TFDebertaPreTrainedModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.TFDebertaPreTrainedModel
:members: call


TFDebertaForMaskedLM
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.TFDebertaForMaskedLM
:members: call


TFDebertaForSequenceClassification
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.TFDebertaForSequenceClassification
:members: call


TFDebertaForTokenClassification
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.TFDebertaForTokenClassification
:members: call


TFDebertaForQuestionAnswering
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.TFDebertaForQuestionAnswering
:members: call
20 changes: 20 additions & 0 deletions src/transformers/__init__.py
Expand Up @@ -1297,6 +1297,17 @@
"TFCTRLPreTrainedModel",
]
)
_import_structure["models.deberta"].extend(
[
"TF_DEBERTA_PRETRAINED_MODEL_ARCHIVE_LIST",
"TFDebertaForMaskedLM",
"TFDebertaForQuestionAnswering",
"TFDebertaForSequenceClassification",
"TFDebertaForTokenClassification",
"TFDebertaModel",
"TFDebertaPreTrainedModel",
]
)
_import_structure["models.distilbert"].extend(
[
"TF_DISTILBERT_PRETRAINED_MODEL_ARCHIVE_LIST",
Expand Down Expand Up @@ -2820,6 +2831,15 @@
TFCTRLModel,
TFCTRLPreTrainedModel,
)
from .models.deberta import (
TF_DEBERTA_PRETRAINED_MODEL_ARCHIVE_LIST,
TFDebertaForMaskedLM,
TFDebertaForQuestionAnswering,
TFDebertaForSequenceClassification,
TFDebertaForTokenClassification,
TFDebertaModel,
TFDebertaPreTrainedModel,
)
from .models.distilbert import (
TF_DISTILBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
TFDistilBertForMaskedLM,
Expand Down
5 changes: 5 additions & 0 deletions src/transformers/models/auto/modeling_tf_auto.py
Expand Up @@ -29,6 +29,7 @@
TF_MODEL_MAPPING_NAMES = OrderedDict(
[
# Base model mapping
("deberta", "TFDebertaModel"),
("rembert", "TFRemBertModel"),
("roformer", "TFRoFormerModel"),
("convbert", "TFConvBertModel"),
Expand Down Expand Up @@ -144,6 +145,7 @@
TF_MODEL_FOR_MASKED_LM_MAPPING_NAMES = OrderedDict(
[
# Model for Masked LM mapping
("deberta", "TFDebertaForMaskedLM"),
("rembert", "TFRemBertForMaskedLM"),
("roformer", "TFRoFormerForMaskedLM"),
("convbert", "TFConvBertForMaskedLM"),
Expand Down Expand Up @@ -183,6 +185,7 @@
TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
[
# Model for Sequence Classification mapping
("deberta", "TFDebertaForSequenceClassification"),
("rembert", "TFRemBertForSequenceClassification"),
("roformer", "TFRoFormerForSequenceClassification"),
("convbert", "TFConvBertForSequenceClassification"),
Expand Down Expand Up @@ -211,6 +214,7 @@
TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
[
# Model for Question Answering mapping
("deberta", "TFDebertaForQuestionAnswering"),
("rembert", "TFRemBertForQuestionAnswering"),
("roformer", "TFRoFormerForQuestionAnswering"),
("convbert", "TFConvBertForQuestionAnswering"),
Expand All @@ -234,6 +238,7 @@
TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
[
# Model for Token Classification mapping
("deberta", "TFDebertaForTokenClassification"),
("rembert", "TFRemBertForTokenClassification"),
("roformer", "TFRoFormerForTokenClassification"),
("convbert", "TFConvBertForTokenClassification"),
Expand Down
25 changes: 24 additions & 1 deletion src/transformers/models/deberta/__init__.py
Expand Up @@ -18,7 +18,7 @@

from typing import TYPE_CHECKING

from ...file_utils import _LazyModule, is_tokenizers_available, is_torch_available
from ...file_utils import _LazyModule, is_tf_available, is_tokenizers_available, is_torch_available


_import_structure = {
Expand All @@ -40,6 +40,17 @@
"DebertaPreTrainedModel",
]

if is_tf_available():
_import_structure["modeling_tf_deberta"] = [
"TF_DEBERTA_PRETRAINED_MODEL_ARCHIVE_LIST",
"TFDebertaForMaskedLM",
"TFDebertaForQuestionAnswering",
"TFDebertaForSequenceClassification",
"TFDebertaForTokenClassification",
"TFDebertaModel",
"TFDebertaPreTrainedModel",
]


if TYPE_CHECKING:
from .configuration_deberta import DEBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, DebertaConfig
Expand All @@ -59,6 +70,18 @@
DebertaPreTrainedModel,
)

if is_tf_available():
from .modeling_tf_deberta import (
TF_DEBERTA_PRETRAINED_MODEL_ARCHIVE_LIST,
TFDebertaForMaskedLM,
TFDebertaForQuestionAnswering,
TFDebertaForSequenceClassification,
TFDebertaForTokenClassification,
TFDebertaModel,
TFDebertaPreTrainedModel,
)


else:
import sys

Expand Down

0 comments on commit d329b63

Please sign in to comment.