Skip to content

Commit

Permalink
TFDeberta
Browse files Browse the repository at this point in the history
moved weights to build and fixed name scope

added missing ,

bug fixes to enable graph mode execution

updated setup.py

fixing typo

fix imports

embedding mask fix

added layer names avoid autmatic incremental names

+XSoftmax

cleanup

added names to layer

disable keras_serializable
Distangled attention output shape hidden_size==None
using symbolic inputs

test for Deberta tf

make style

Update src/transformers/models/deberta/modeling_tf_deberta.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

Update src/transformers/models/deberta/modeling_tf_deberta.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

Update src/transformers/models/deberta/modeling_tf_deberta.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

Update src/transformers/models/deberta/modeling_tf_deberta.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

Update src/transformers/models/deberta/modeling_tf_deberta.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

Update src/transformers/models/deberta/modeling_tf_deberta.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

Update src/transformers/models/deberta/modeling_tf_deberta.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

removed tensorflow-probability

removed blank line
  • Loading branch information
kamalkraj committed Aug 7, 2021
1 parent 7fcee11 commit f84fa5e
Show file tree
Hide file tree
Showing 7 changed files with 1,907 additions and 2 deletions.
2 changes: 1 addition & 1 deletion docs/source/index.rst
Expand Up @@ -343,7 +343,7 @@ Flax), PyTorch, and/or TensorFlow.
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| DPR ||||||
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| DeBERTa |||| ||
| DeBERTa |||| ||
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| DeBERTa-v2 ||||||
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
Expand Down
20 changes: 20 additions & 0 deletions src/transformers/__init__.py
Expand Up @@ -1297,6 +1297,17 @@
"TFCTRLPreTrainedModel",
]
)
_import_structure["models.deberta"].extend(
[
"TF_DEBERTA_PRETRAINED_MODEL_ARCHIVE_LIST",
"TFDebertaForMaskedLM",
"TFDebertaForQuestionAnswering",
"TFDebertaForSequenceClassification",
"TFDebertaForTokenClassification",
"TFDebertaModel",
"TFDebertaPreTrainedModel",
]
)
_import_structure["models.distilbert"].extend(
[
"TF_DISTILBERT_PRETRAINED_MODEL_ARCHIVE_LIST",
Expand Down Expand Up @@ -2820,6 +2831,15 @@
TFCTRLModel,
TFCTRLPreTrainedModel,
)
from .models.deberta import (
TF_DEBERTA_PRETRAINED_MODEL_ARCHIVE_LIST,
TFDebertaForMaskedLM,
TFDebertaForQuestionAnswering,
TFDebertaForSequenceClassification,
TFDebertaForTokenClassification,
TFDebertaModel,
TFDebertaPreTrainedModel,
)
from .models.distilbert import (
TF_DISTILBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
TFDistilBertForMaskedLM,
Expand Down
5 changes: 5 additions & 0 deletions src/transformers/models/auto/modeling_tf_auto.py
Expand Up @@ -29,6 +29,7 @@
TF_MODEL_MAPPING_NAMES = OrderedDict(
[
# Base model mapping
("deberta", "TFDebertaModel"),
("rembert", "TFRemBertModel"),
("roformer", "TFRoFormerModel"),
("convbert", "TFConvBertModel"),
Expand Down Expand Up @@ -144,6 +145,7 @@
TF_MODEL_FOR_MASKED_LM_MAPPING_NAMES = OrderedDict(
[
# Model for Masked LM mapping
("deberta", "TFDebertaForMaskedLM"),
("rembert", "TFRemBertForMaskedLM"),
("roformer", "TFRoFormerForMaskedLM"),
("convbert", "TFConvBertForMaskedLM"),
Expand Down Expand Up @@ -183,6 +185,7 @@
TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
[
# Model for Sequence Classification mapping
("deberta", "TFDebertaForSequenceClassification"),
("rembert", "TFRemBertForSequenceClassification"),
("roformer", "TFRoFormerForSequenceClassification"),
("convbert", "TFConvBertForSequenceClassification"),
Expand Down Expand Up @@ -211,6 +214,7 @@
TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
[
# Model for Question Answering mapping
("deberta", "TFDebertaForQuestionAnswering"),
("rembert", "TFRemBertForQuestionAnswering"),
("roformer", "TFRoFormerForQuestionAnswering"),
("convbert", "TFConvBertForQuestionAnswering"),
Expand All @@ -234,6 +238,7 @@
TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
[
# Model for Token Classification mapping
("deberta", "TFDebertaForTokenClassification"),
("rembert", "TFRemBertForTokenClassification"),
("roformer", "TFRoFormerForTokenClassification"),
("convbert", "TFConvBertForTokenClassification"),
Expand Down
25 changes: 24 additions & 1 deletion src/transformers/models/deberta/__init__.py
Expand Up @@ -18,7 +18,7 @@

from typing import TYPE_CHECKING

from ...file_utils import _LazyModule, is_tokenizers_available, is_torch_available
from ...file_utils import _LazyModule, is_tf_available, is_tokenizers_available, is_torch_available


_import_structure = {
Expand All @@ -40,6 +40,17 @@
"DebertaPreTrainedModel",
]

if is_tf_available():
_import_structure["modeling_tf_deberta"] = [
"TF_DEBERTA_PRETRAINED_MODEL_ARCHIVE_LIST",
"TFDebertaForMaskedLM",
"TFDebertaForQuestionAnswering",
"TFDebertaForSequenceClassification",
"TFDebertaForTokenClassification",
"TFDebertaModel",
"TFDebertaPreTrainedModel",
]


if TYPE_CHECKING:
from .configuration_deberta import DEBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, DebertaConfig
Expand All @@ -59,6 +70,18 @@
DebertaPreTrainedModel,
)

if is_tf_available():
from .modeling_tf_deberta import (
TF_DEBERTA_PRETRAINED_MODEL_ARCHIVE_LIST,
TFDebertaForMaskedLM,
TFDebertaForQuestionAnswering,
TFDebertaForSequenceClassification,
TFDebertaForTokenClassification,
TFDebertaModel,
TFDebertaPreTrainedModel,
)


else:
import sys

Expand Down

0 comments on commit f84fa5e

Please sign in to comment.