distilbert-flax #13324

Merged
merged 9 commits on Aug 30, 2021
2 changes: 1 addition & 1 deletion docs/source/index.rst
@@ -350,7 +350,7 @@ Flax), PyTorch, and/or TensorFlow.
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| DETR | ❌ | ❌ | ✅ | ❌ | ❌ |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
-| DistilBERT | ✅ | ✅ | ✅ | ✅ | ❌ |
+| DistilBERT | ✅ | ✅ | ✅ | ✅ | ✅ |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| DPR | ✅ | ✅ | ✅ | ✅ | ❌ |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
47 changes: 45 additions & 2 deletions docs/source/model_doc/distilbert.rst
@@ -44,8 +44,9 @@ Tips:
- DistilBERT doesn't have options to select the input positions (:obj:`position_ids` input). This could be added if
  necessary; just let us know if you need this option.

-This model was contributed by `victorsanh <https://huggingface.co/victorsanh>`__. The original code can be found
-:prefix_link:`here <examples/research-projects/distillation>`.
+This model was contributed by `victorsanh <https://huggingface.co/victorsanh>`__. The Flax version of this model was
+contributed by `kamalkraj <https://huggingface.co/kamalkraj>`__. The original code can be found :prefix_link:`here
+<examples/research-projects/distillation>`.


DistilBertConfig
@@ -152,3 +153,45 @@ TFDistilBertForQuestionAnswering

.. autoclass:: transformers.TFDistilBertForQuestionAnswering
:members: call


FlaxDistilBertModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.FlaxDistilBertModel
:members: __call__


FlaxDistilBertForMaskedLM
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.FlaxDistilBertForMaskedLM
:members: __call__


FlaxDistilBertForSequenceClassification
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.FlaxDistilBertForSequenceClassification
:members: __call__


FlaxDistilBertForMultipleChoice
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.FlaxDistilBertForMultipleChoice
:members: __call__


FlaxDistilBertForTokenClassification
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.FlaxDistilBertForTokenClassification
:members: __call__


FlaxDistilBertForQuestionAnswering
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.FlaxDistilBertForQuestionAnswering
:members: __call__
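
With the Flax classes documented above, basic usage mirrors the existing PyTorch/TensorFlow API. A minimal sketch, assuming the stock `distilbert-base-uncased` checkpoint (if a Hub repo lacks Flax weights, passing `from_pt=True` converts the PyTorch ones):

```python
from transformers import DistilBertTokenizerFast, FlaxDistilBertModel

tokenizer = DistilBertTokenizerFast.from_pretrained("distilbert-base-uncased")
model = FlaxDistilBertModel.from_pretrained("distilbert-base-uncased")

# Flax models consume NumPy arrays, hence return_tensors="np".
inputs = tokenizer("Hello, my dog is cute", return_tensors="np")
outputs = model(**inputs)
last_hidden_state = outputs.last_hidden_state  # (batch, seq_len, hidden_size)
```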
20 changes: 20 additions & 0 deletions src/transformers/__init__.py
@@ -1693,6 +1693,17 @@
"FlaxCLIPVisionPreTrainedModel",
]
)
_import_structure["models.distilbert"].extend(
[
"FlaxDistilBertForMaskedLM",
"FlaxDistilBertForMultipleChoice",
"FlaxDistilBertForQuestionAnswering",
"FlaxDistilBertForSequenceClassification",
"FlaxDistilBertForTokenClassification",
"FlaxDistilBertModel",
"FlaxDistilBertPreTrainedModel",
]
)
_import_structure["models.electra"].extend(
[
"FlaxElectraForMaskedLM",
@@ -3166,6 +3177,15 @@
FlaxCLIPVisionModel,
FlaxCLIPVisionPreTrainedModel,
)
from .models.distilbert import (
FlaxDistilBertForMaskedLM,
FlaxDistilBertForMultipleChoice,
FlaxDistilBertForQuestionAnswering,
FlaxDistilBertForSequenceClassification,
FlaxDistilBertForTokenClassification,
FlaxDistilBertModel,
FlaxDistilBertPreTrainedModel,
)
from .models.electra import (
FlaxElectraForMaskedLM,
FlaxElectraForMultipleChoice,
6 changes: 6 additions & 0 deletions src/transformers/models/auto/modeling_flax_auto.py
@@ -28,6 +28,7 @@
FLAX_MODEL_MAPPING_NAMES = OrderedDict(
[
# Base model mapping
("distilbert", "FlaxDistilBertModel"),
("roberta", "FlaxRobertaModel"),
("bert", "FlaxBertModel"),
("big_bird", "FlaxBigBirdModel"),
@@ -63,6 +64,7 @@
FLAX_MODEL_FOR_MASKED_LM_MAPPING_NAMES = OrderedDict(
[
# Model for Masked LM mapping
("distilbert", "FlaxDistilBertForMaskedLM"),
("roberta", "FlaxRobertaForMaskedLM"),
("bert", "FlaxBertForMaskedLM"),
("big_bird", "FlaxBigBirdForMaskedLM"),
@@ -101,6 +103,7 @@
FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
[
# Model for Sequence Classification mapping
("distilbert", "FlaxDistilBertForSequenceClassification"),
("roberta", "FlaxRobertaForSequenceClassification"),
("bert", "FlaxBertForSequenceClassification"),
("big_bird", "FlaxBigBirdForSequenceClassification"),
@@ -113,6 +116,7 @@
FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
[
# Model for Question Answering mapping
("distilbert", "FlaxDistilBertForQuestionAnswering"),
("roberta", "FlaxRobertaForQuestionAnswering"),
("bert", "FlaxBertForQuestionAnswering"),
("big_bird", "FlaxBigBirdForQuestionAnswering"),
@@ -125,6 +129,7 @@
FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
[
# Model for Token Classification mapping
("distilbert", "FlaxDistilBertForTokenClassification"),
("roberta", "FlaxRobertaForTokenClassification"),
("bert", "FlaxBertForTokenClassification"),
("big_bird", "FlaxBigBirdForTokenClassification"),
@@ -135,6 +140,7 @@
FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES = OrderedDict(
[
# Model for Multiple Choice mapping
("distilbert", "FlaxDistilBertForMultipleChoice"),
("roberta", "FlaxRobertaForMultipleChoice"),
("bert", "FlaxBertForMultipleChoice"),
("big_bird", "FlaxBigBirdForMultipleChoice"),
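With these mapping entries in place, the Flax auto classes resolve DistilBERT checkpoints from the `model_type` in their config; a short sketch, using the existing `FlaxAutoModelFor*` pattern:

```python
from transformers import FlaxAutoModelForSequenceClassification

# config.model_type == "distilbert" now routes to
# FlaxDistilBertForSequenceClassification via the mapping above.
model = FlaxAutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased")
```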
24 changes: 23 additions & 1 deletion src/transformers/models/distilbert/__init__.py
@@ -18,7 +18,7 @@

from typing import TYPE_CHECKING

-from ...file_utils import _LazyModule, is_tf_available, is_tokenizers_available, is_torch_available
+from ...file_utils import _LazyModule, is_flax_available, is_tf_available, is_tokenizers_available, is_torch_available


_import_structure = {
@@ -58,6 +58,17 @@
"TFDistilBertPreTrainedModel",
]

if is_flax_available():
_import_structure["modeling_flax_distilbert"] = [
"FlaxDistilBertForMaskedLM",
"FlaxDistilBertForMultipleChoice",
"FlaxDistilBertForQuestionAnswering",
"FlaxDistilBertForSequenceClassification",
"FlaxDistilBertForTokenClassification",
"FlaxDistilBertModel",
"FlaxDistilBertPreTrainedModel",
]


if TYPE_CHECKING:
from .configuration_distilbert import (
@@ -95,6 +106,17 @@
TFDistilBertPreTrainedModel,
)

if is_flax_available():
from .modeling_flax_distilbert import (
FlaxDistilBertForMaskedLM,
FlaxDistilBertForMultipleChoice,
FlaxDistilBertForQuestionAnswering,
FlaxDistilBertForSequenceClassification,
FlaxDistilBertForTokenClassification,
FlaxDistilBertModel,
FlaxDistilBertPreTrainedModel,
)

else:
import sys

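Downstream code can mirror the availability guard used in this `__init__.py` so imports degrade gracefully when JAX/Flax is not installed; a minimal sketch:

```python
from transformers.file_utils import is_flax_available

if is_flax_available():
    from transformers import FlaxDistilBertForMaskedLM

    model = FlaxDistilBertForMaskedLM.from_pretrained("distilbert-base-uncased")
else:
    # e.g. pip install transformers[flax]
    raise ImportError("This example requires JAX/Flax to be installed.")
```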