albert flax #13294

Merged: 6 commits, Aug 30, 2021
2 changes: 1 addition & 1 deletion docs/source/index.rst
@@ -321,7 +321,7 @@ Flax), PyTorch, and/or TensorFlow.
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| Model | Tokenizer slow | Tokenizer fast | PyTorch support | TensorFlow support | Flax Support |
+=============================+================+================+=================+====================+==============+
- | ALBERT                      | ✅             | ✅             | ✅              | ✅                 |              |
+ | ALBERT                      | ✅             | ✅             | ✅              | ✅                 | ✅           |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| BART | ✅ | ✅ | ✅ | ✅ | ✅ |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
52 changes: 51 additions & 1 deletion docs/source/model_doc/albert.rst
@@ -43,7 +43,8 @@ Tips:
similar to a BERT-like architecture with the same number of hidden layers as it has to iterate through the same
number of (repeating) layers.

- This model was contributed by `lysandre <https://huggingface.co/lysandre>`__. The original code can be found `here
+ This model was contributed by `lysandre <https://huggingface.co/lysandre>`__. The Flax version of this model was
+ contributed by `kamalkraj <https://huggingface.co/kamalkraj>`__. The original code can be found `here
<https://github.com/google-research/ALBERT>`__.
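The context lines above capture ALBERT's key trade-off: sharing one set of layer parameters shrinks the memory footprint, but the forward pass still iterates through the same number of layers, so compute cost stays BERT-like. A minimal sketch of the idea (the helpers `embed`, `make_layer_params`, and `transformer_layer` are hypothetical, not transformers API):

```python
# Hypothetical sketch of ALBERT-style cross-layer parameter sharing.
params = make_layer_params(config)         # ONE set of layer weights
hidden_states = embed(input_ids)
for _ in range(config.num_hidden_layers):  # the same weights are reused
    hidden_states = transformer_layer(hidden_states, params)
# Parameters stored: one layer's worth; layers executed: num_hidden_layers.
```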

AlbertConfig
@@ -174,3 +175,52 @@ TFAlbertForQuestionAnswering

.. autoclass:: transformers.TFAlbertForQuestionAnswering
:members: call


FlaxAlbertModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.FlaxAlbertModel
:members: __call__


FlaxAlbertForPreTraining
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.FlaxAlbertForPreTraining
:members: __call__


FlaxAlbertForMaskedLM
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.FlaxAlbertForMaskedLM
:members: __call__


FlaxAlbertForSequenceClassification
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.FlaxAlbertForSequenceClassification
:members: __call__


FlaxAlbertForMultipleChoice
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.FlaxAlbertForMultipleChoice
:members: __call__


FlaxAlbertForTokenClassification
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.FlaxAlbertForTokenClassification
:members: __call__


FlaxAlbertForQuestionAnswering
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.FlaxAlbertForQuestionAnswering
:members: __call__
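As a complement to the autoclass stubs above, here is a minimal usage sketch for the new model (assuming `jax` and `flax` are installed and that Flax weights exist for the `albert-base-v2` checkpoint):

```python
from transformers import AlbertTokenizer, FlaxAlbertModel

tokenizer = AlbertTokenizer.from_pretrained("albert-base-v2")
model = FlaxAlbertModel.from_pretrained("albert-base-v2")

# Flax models consume NumPy/JAX arrays, hence return_tensors="np".
inputs = tokenizer("Hello, my dog is cute", return_tensors="np")
outputs = model(**inputs)
last_hidden_states = outputs.last_hidden_state  # (batch, seq_len, hidden_size)
```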
22 changes: 22 additions & 0 deletions src/transformers/__init__.py
@@ -1642,6 +1642,18 @@
"FlaxTopPLogitsWarper",
]
_import_structure["modeling_flax_utils"] = ["FlaxPreTrainedModel"]
_import_structure["models.albert"].extend(
[
"FlaxAlbertForMaskedLM",
"FlaxAlbertForMultipleChoice",
"FlaxAlbertForPreTraining",
"FlaxAlbertForQuestionAnswering",
"FlaxAlbertForSequenceClassification",
"FlaxAlbertForTokenClassification",
"FlaxAlbertModel",
"FlaxAlbertPreTrainedModel",
]
)
_import_structure["models.auto"].extend(
[
"FLAX_MODEL_FOR_CAUSAL_LM_MAPPING",
@@ -3152,6 +3164,16 @@
FlaxTopPLogitsWarper,
)
from .modeling_flax_utils import FlaxPreTrainedModel
from .models.albert import (
FlaxAlbertForMaskedLM,
FlaxAlbertForMultipleChoice,
FlaxAlbertForPreTraining,
FlaxAlbertForQuestionAnswering,
FlaxAlbertForSequenceClassification,
FlaxAlbertForTokenClassification,
FlaxAlbertModel,
FlaxAlbertPreTrainedModel,
)
from .models.auto import (
FLAX_MODEL_FOR_CAUSAL_LM_MAPPING,
FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
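The hunks above register the new class names in two places: in `_import_structure`, which feeds the library's lazy top-level module (a class is only imported when first accessed), and as real imports under `TYPE_CHECKING` so IDEs and static type checkers can resolve them. A simplified sketch of the lazy-lookup mechanism (illustrative only, not the actual `_LazyModule` implementation):

```python
import importlib
from types import ModuleType


class LazyModuleSketch(ModuleType):
    """Resolves attribute access to a deferred submodule import."""

    def __init__(self, name, import_structure):
        super().__init__(name)
        # Invert {"models.albert": ["FlaxAlbertModel", ...]} into
        # {"FlaxAlbertModel": "models.albert", ...} for O(1) lookup.
        self._class_to_module = {
            cls: mod for mod, classes in import_structure.items() for cls in classes
        }

    def __getattr__(self, name):
        # Import the owning submodule only now, on first access.
        submodule = importlib.import_module("." + self._class_to_module[name], self.__name__)
        return getattr(submodule, name)
```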
23 changes: 23 additions & 0 deletions src/transformers/models/albert/__init__.py
@@ -20,6 +20,7 @@

from ...file_utils import (
_LazyModule,
is_flax_available,
is_sentencepiece_available,
is_tf_available,
is_tokenizers_available,
@@ -65,6 +66,17 @@
"TFAlbertPreTrainedModel",
]

if is_flax_available():
_import_structure["modeling_flax_albert"] = [
"FlaxAlbertForMaskedLM",
"FlaxAlbertForMultipleChoice",
"FlaxAlbertForPreTraining",
"FlaxAlbertForQuestionAnswering",
"FlaxAlbertForSequenceClassification",
"FlaxAlbertForTokenClassification",
"FlaxAlbertModel",
"FlaxAlbertPreTrainedModel",
]

if TYPE_CHECKING:
from .configuration_albert import ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, AlbertConfig, AlbertOnnxConfig
@@ -103,6 +115,17 @@
TFAlbertPreTrainedModel,
)

if is_flax_available():
from .modeling_flax_albert import (
FlaxAlbertForMaskedLM,
FlaxAlbertForMultipleChoice,
FlaxAlbertForPreTraining,
FlaxAlbertForQuestionAnswering,
FlaxAlbertForSequenceClassification,
FlaxAlbertForTokenClassification,
FlaxAlbertModel,
FlaxAlbertPreTrainedModel,
)
else:
import sys

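Both registration sites are guarded by `is_flax_available()`, so the `FlaxAlbert*` classes are exported only when `jax` and `flax` are importable. A quick sanity check, sketched against the import paths of this era of the library (`is_flax_available` lives in `transformers.file_utils` here):

```python
from transformers.file_utils import is_flax_available

if is_flax_available():
    from transformers import FlaxAlbertForMaskedLM

    model = FlaxAlbertForMaskedLM.from_pretrained("albert-base-v2")
else:
    print("jax/flax not installed; the FlaxAlbert* classes are not exported.")
```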