Skip to content

Commit

Permalink
distilbert-flax (#13324)
Browse files Browse the repository at this point in the history
* distilbert-flax

* added missing self

* docs fix

* removed tied kernal extra init

* updated docs

* x -> hidden states

* removed head_mask

* removed from_pt, +FLAX

* updated year
  • Loading branch information
kamalkraj committed Aug 30, 2021
1 parent 0197746 commit 774760e
Show file tree
Hide file tree
Showing 8 changed files with 1,205 additions and 4 deletions.
2 changes: 1 addition & 1 deletion docs/source/index.rst
Expand Up @@ -357,7 +357,7 @@ Flax), PyTorch, and/or TensorFlow.
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| DETR ||||||
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| DistilBERT ||||| |
| DistilBERT ||||| |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| DPR ||||||
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
Expand Down
47 changes: 45 additions & 2 deletions docs/source/model_doc/distilbert.rst
Expand Up @@ -44,8 +44,9 @@ Tips:
- DistilBERT doesn't have options to select the input positions (:obj:`position_ids` input). This could be added if
necessary though, just let us know if you need this option.

This model was contributed by `victorsanh <https://huggingface.co/victorsanh>`__. The original code can be found
:prefix_link:`here <examples/research-projects/distillation>`.
This model was contributed by `victorsanh <https://huggingface.co/victorsanh>`__. This model jax version was
contributed by `kamalkraj <https://huggingface.co/kamalkraj>`__. The original code can be found :prefix_link:`here
<examples/research-projects/distillation>`.


DistilBertConfig
Expand Down Expand Up @@ -152,3 +153,45 @@ TFDistilBertForQuestionAnswering

.. autoclass:: transformers.TFDistilBertForQuestionAnswering
:members: call


FlaxDistilBertModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.FlaxDistilBertModel
:members: __call__


FlaxDistilBertForMaskedLM
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.FlaxDistilBertForMaskedLM
:members: __call__


FlaxDistilBertForSequenceClassification
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.FlaxDistilBertForSequenceClassification
:members: __call__


FlaxDistilBertForMultipleChoice
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.FlaxDistilBertForMultipleChoice
:members: __call__


FlaxDistilBertForTokenClassification
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.FlaxDistilBertForTokenClassification
:members: __call__


FlaxDistilBertForQuestionAnswering
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.FlaxDistilBertForQuestionAnswering
:members: __call__
20 changes: 20 additions & 0 deletions src/transformers/__init__.py
Expand Up @@ -1712,6 +1712,17 @@
"FlaxCLIPVisionPreTrainedModel",
]
)
_import_structure["models.distilbert"].extend(
[
"FlaxDistilBertForMaskedLM",
"FlaxDistilBertForMultipleChoice",
"FlaxDistilBertForQuestionAnswering",
"FlaxDistilBertForSequenceClassification",
"FlaxDistilBertForTokenClassification",
"FlaxDistilBertModel",
"FlaxDistilBertPreTrainedModel",
]
)
_import_structure["models.electra"].extend(
[
"FlaxElectraForMaskedLM",
Expand Down Expand Up @@ -3201,6 +3212,15 @@
FlaxCLIPVisionModel,
FlaxCLIPVisionPreTrainedModel,
)
from .models.distilbert import (
FlaxDistilBertForMaskedLM,
FlaxDistilBertForMultipleChoice,
FlaxDistilBertForQuestionAnswering,
FlaxDistilBertForSequenceClassification,
FlaxDistilBertForTokenClassification,
FlaxDistilBertModel,
FlaxDistilBertPreTrainedModel,
)
from .models.electra import (
FlaxElectraForMaskedLM,
FlaxElectraForMultipleChoice,
Expand Down
6 changes: 6 additions & 0 deletions src/transformers/models/auto/modeling_flax_auto.py
Expand Up @@ -28,6 +28,7 @@
FLAX_MODEL_MAPPING_NAMES = OrderedDict(
[
# Base model mapping
("distilbert", "FlaxDistilBertModel"),
("roberta", "FlaxRobertaModel"),
("bert", "FlaxBertModel"),
("big_bird", "FlaxBigBirdModel"),
Expand Down Expand Up @@ -63,6 +64,7 @@
FLAX_MODEL_FOR_MASKED_LM_MAPPING_NAMES = OrderedDict(
[
# Model for Masked LM mapping
("distilbert", "FlaxDistilBertForMaskedLM"),
("roberta", "FlaxRobertaForMaskedLM"),
("bert", "FlaxBertForMaskedLM"),
("big_bird", "FlaxBigBirdForMaskedLM"),
Expand Down Expand Up @@ -101,6 +103,7 @@
FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
[
# Model for Sequence Classification mapping
("distilbert", "FlaxDistilBertForSequenceClassification"),
("roberta", "FlaxRobertaForSequenceClassification"),
("bert", "FlaxBertForSequenceClassification"),
("big_bird", "FlaxBigBirdForSequenceClassification"),
Expand All @@ -113,6 +116,7 @@
FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
[
# Model for Question Answering mapping
("distilbert", "FlaxDistilBertForQuestionAnswering"),
("roberta", "FlaxRobertaForQuestionAnswering"),
("bert", "FlaxBertForQuestionAnswering"),
("big_bird", "FlaxBigBirdForQuestionAnswering"),
Expand All @@ -125,6 +129,7 @@
FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
[
# Model for Token Classification mapping
("distilbert", "FlaxDistilBertForTokenClassification"),
("roberta", "FlaxRobertaForTokenClassification"),
("bert", "FlaxBertForTokenClassification"),
("big_bird", "FlaxBigBirdForTokenClassification"),
Expand All @@ -135,6 +140,7 @@
FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES = OrderedDict(
[
# Model for Multiple Choice mapping
("distilbert", "FlaxDistilBertForMultipleChoice"),
("roberta", "FlaxRobertaForMultipleChoice"),
("bert", "FlaxBertForMultipleChoice"),
("big_bird", "FlaxBigBirdForMultipleChoice"),
Expand Down
24 changes: 23 additions & 1 deletion src/transformers/models/distilbert/__init__.py
Expand Up @@ -18,7 +18,7 @@

from typing import TYPE_CHECKING

from ...file_utils import _LazyModule, is_tf_available, is_tokenizers_available, is_torch_available
from ...file_utils import _LazyModule, is_flax_available, is_tf_available, is_tokenizers_available, is_torch_available


_import_structure = {
Expand Down Expand Up @@ -58,6 +58,17 @@
"TFDistilBertPreTrainedModel",
]

if is_flax_available():
_import_structure["modeling_flax_distilbert"] = [
"FlaxDistilBertForMaskedLM",
"FlaxDistilBertForMultipleChoice",
"FlaxDistilBertForQuestionAnswering",
"FlaxDistilBertForSequenceClassification",
"FlaxDistilBertForTokenClassification",
"FlaxDistilBertModel",
"FlaxDistilBertPreTrainedModel",
]


if TYPE_CHECKING:
from .configuration_distilbert import (
Expand Down Expand Up @@ -95,6 +106,17 @@
TFDistilBertPreTrainedModel,
)

if is_flax_available():
from .modeling_flax_distilbert import (
FlaxDistilBertForMaskedLM,
FlaxDistilBertForMultipleChoice,
FlaxDistilBertForQuestionAnswering,
FlaxDistilBertForSequenceClassification,
FlaxDistilBertForTokenClassification,
FlaxDistilBertModel,
FlaxDistilBertPreTrainedModel,
)

else:
import sys

Expand Down

0 comments on commit 774760e

Please sign in to comment.