Commit

albert flax
kamalkraj committed Aug 30, 2021
1 parent d506495 commit 2cfaf97
Showing 8 changed files with 1,477 additions and 2 deletions.
2 changes: 1 addition & 1 deletion docs/source/index.rst
@@ -321,7 +321,7 @@ Flax), PyTorch, and/or TensorFlow.
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| Model                       | Tokenizer slow | Tokenizer fast | PyTorch support | TensorFlow support | Flax Support |
+=============================+================+================+=================+====================+==============+
| ALBERT                      | ✅             | ✅             | ✅              | ✅                 | ❌           |
| ALBERT                      | ✅             | ✅             | ✅              | ✅                 | ✅           |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| BART                        | ✅             | ✅             | ✅              | ✅                 | ✅           |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
52 changes: 51 additions & 1 deletion docs/source/model_doc/albert.rst
@@ -43,7 +43,8 @@ Tips:
similar to a BERT-like architecture with the same number of hidden layers as it has to iterate through the same
number of (repeating) layers.

This model was contributed by `lysandre <https://huggingface.co/lysandre>`__. The original code can be found `here
This model was contributed by `lysandre <https://huggingface.co/lysandre>`__. The jax version of this model was
contributed by `kamalkraj <https://huggingface.co/kamalkraj>`__. The original code can be found `here
<https://github.com/google-research/ALBERT>`__.

AlbertConfig
@@ -174,3 +175,52 @@ TFAlbertForQuestionAnswering

.. autoclass:: transformers.TFAlbertForQuestionAnswering
:members: call


FlaxAlbertModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.FlaxAlbertModel
:members: __call__


FlaxAlbertForPreTraining
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.FlaxAlbertForPreTraining
:members: __call__


FlaxAlbertForMaskedLM
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.FlaxAlbertForMaskedLM
:members: __call__


FlaxAlbertForSequenceClassification
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.FlaxAlbertForSequenceClassification
:members: __call__


FlaxAlbertForMultipleChoice
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.FlaxAlbertForMultipleChoice
:members: __call__


FlaxAlbertForTokenClassification
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.FlaxAlbertForTokenClassification
:members: __call__


FlaxAlbertForQuestionAnswering
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.FlaxAlbertForQuestionAnswering
:members: __call__
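The new Flax classes are used like their PyTorch and TensorFlow counterparts. A hedged usage sketch for masked-LM inference follows; it assumes the optional ``jax``, ``flax`` and ``transformers`` dependencies are installed (imports are kept inside the function so the sketch can be defined without them), and ``albert-base-v2`` is the standard public checkpoint:

```python
def albert_flax_mlm_demo():
    """Sketch: fill a masked token with FlaxAlbertForMaskedLM.

    Requires ``transformers``, ``flax`` and ``jax``; downloads the
    ``albert-base-v2`` checkpoint on first use.
    """
    import jax.numpy as jnp
    from transformers import AlbertTokenizer, FlaxAlbertForMaskedLM

    tokenizer = AlbertTokenizer.from_pretrained("albert-base-v2")
    model = FlaxAlbertForMaskedLM.from_pretrained("albert-base-v2")

    inputs = tokenizer("The capital of France is [MASK].", return_tensors="jax")
    logits = model(**inputs).logits

    # Locate the [MASK] position, then take the highest-scoring vocab entry.
    mask_index = jnp.where(inputs["input_ids"][0] == tokenizer.mask_token_id)[0][0]
    predicted_id = int(logits[0, mask_index].argmax(axis=-1))
    return tokenizer.decode([predicted_id])
```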
22 changes: 22 additions & 0 deletions src/transformers/__init__.py
@@ -1642,6 +1642,18 @@
"FlaxTopPLogitsWarper",
]
_import_structure["modeling_flax_utils"] = ["FlaxPreTrainedModel"]
_import_structure["models.albert"].extend(
[
"FlaxAlbertForMaskedLM",
"FlaxAlbertForMultipleChoice",
"FlaxAlbertForPreTraining",
"FlaxAlbertForQuestionAnswering",
"FlaxAlbertForSequenceClassification",
"FlaxAlbertForTokenClassification",
"FlaxAlbertModel",
"FlaxAlbertPreTrainedModel",
]
)
_import_structure["models.auto"].extend(
[
"FLAX_MODEL_FOR_CAUSAL_LM_MAPPING",
@@ -3152,6 +3164,16 @@
FlaxTopPLogitsWarper,
)
from .modeling_flax_utils import FlaxPreTrainedModel
from .models.albert import (
FlaxAlbertForMaskedLM,
FlaxAlbertForMultipleChoice,
FlaxAlbertForPreTraining,
FlaxAlbertForQuestionAnswering,
FlaxAlbertForSequenceClassification,
FlaxAlbertForTokenClassification,
FlaxAlbertModel,
FlaxAlbertPreTrainedModel,
)
from .models.auto import (
FLAX_MODEL_FOR_CAUSAL_LM_MAPPING,
FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
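The `_import_structure` entries added above feed Transformers' `_LazyModule`, which defers importing the heavy modeling files until one of the exported names is actually accessed. A minimal stdlib-only sketch of the same idea — the class name and toy mapping are illustrative, not the library's real implementation:

```python
import importlib
import types


class LazyModule(types.ModuleType):
    """Sketch of a lazy module: record which submodule exports each name,
    and import that submodule only when the name is first accessed."""

    def __init__(self, name, import_structure):
        super().__init__(name)
        # e.g. {"modeling_flax_albert": ["FlaxAlbertModel", ...]}
        self._name_to_module = {
            symbol: module
            for module, symbols in import_structure.items()
            for symbol in symbols
        }

    def __getattr__(self, name):
        # Only called when normal attribute lookup fails.
        if name not in self._name_to_module:
            raise AttributeError(f"module {self.__name__!r} has no attribute {name!r}")
        module = importlib.import_module(self._name_to_module[name])
        value = getattr(module, name)
        setattr(self, name, value)  # cache so later lookups skip __getattr__
        return value


# Toy demonstration against the standard library: "json" stands in for a
# heavy modeling file, and "dumps" for an exported model class.
lazy = LazyModule("demo", {"json": ["dumps"]})
assert lazy.dumps({"a": 1}) == '{"a": 1}'
```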
23 changes: 23 additions & 0 deletions src/transformers/models/albert/__init__.py
@@ -20,6 +20,7 @@

from ...file_utils import (
_LazyModule,
is_flax_available,
is_sentencepiece_available,
is_tf_available,
is_tokenizers_available,
@@ -65,6 +66,17 @@
"TFAlbertPreTrainedModel",
]

if is_flax_available():
_import_structure["modeling_flax_albert"] = [
"FlaxAlbertForMaskedLM",
"FlaxAlbertForMultipleChoice",
"FlaxAlbertForPreTraining",
"FlaxAlbertForQuestionAnswering",
"FlaxAlbertForSequenceClassification",
"FlaxAlbertForTokenClassification",
"FlaxAlbertModel",
"FlaxAlbertPreTrainedModel",
]

if TYPE_CHECKING:
from .configuration_albert import ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, AlbertConfig, AlbertOnnxConfig
@@ -103,6 +115,17 @@
TFAlbertPreTrainedModel,
)

if is_flax_available():
from .modeling_flax_albert import (
FlaxAlbertForMaskedLM,
FlaxAlbertForMultipleChoice,
FlaxAlbertForPreTraining,
FlaxAlbertForQuestionAnswering,
FlaxAlbertForSequenceClassification,
FlaxAlbertForTokenClassification,
FlaxAlbertModel,
FlaxAlbertPreTrainedModel,
)
else:
import sys

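The `is_flax_available()` guard keeps the Flax classes out of the import structure when `flax` is not installed. A common way to implement such a check — shown here as a stdlib sketch, not necessarily the exact code in `file_utils` — is to probe for the package's spec without importing it:

```python
import importlib.util


def is_package_available(name: str) -> bool:
    """Sketch of the idea behind checks like ``is_flax_available()``:
    report whether a top-level package can be imported, without paying
    the cost (or risking the failure) of actually importing it."""
    return importlib.util.find_spec(name) is not None


# Standard-library modules are always available; a made-up name is not.
assert is_package_available("json") is True
assert is_package_available("surely_not_a_real_package") is False
```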
