Make Flax GPT2 work with cross attention (#13008)
* make flax gpt2 working with cross attention

* Remove encoder->decoder projection layer

* A draft (incomplete) for FlaxEncoderDecoderModel

* Add the method from_encoder_decoder_pretrained + the docstrings

* Fix the mistakes of using EncoderDecoderModel

* Fix style

* Add FlaxEncoderDecoderModel to the library

* Fix cyclic imports

* Add FlaxEncoderDecoderModel to modeling_flax_auto.py

* Remove question comments

* add tests for FlaxEncoderDecoderModel

* add flax_encoder_decoder to the lists of ignored entries in check_repo.py

* fix missing required positional arguments

* Remove **kwargs when creating FlaxEncoderDecoderModel in from_encoder_decoder_pretrained()

Also fix generation eos/pad tokens issue

* Fix: Use sequences from the generated_output

* Change a check from assert to raise ValueError

* Fix examples and token ids issues

* Fix missing all_cross_attentions when outputting tuple in modeling_gpt2

* Remove the changes in configuration docstrings.

* allow for bert 2 gpt2

* make fix-copies

* Apply suggestions from code review

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Change remaining examples to bert2gpt2

* Change the test to Bert2GPT2

* Fix examples

* Fix import

* Fix unpack bug

* Rename to FlaxEncoderDecoderModelTest and change the test to bert2gpt2

* Apply suggestions from code review

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Fix: NotImplentedError -> NotImplementedError

* Apply suggestions from code review

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* up

* finalize

Co-authored-by: ydshieh <ydshieh@user.noreply>
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
4 people committed Aug 23, 2021
1 parent 7223844 commit 2e20c0f
Showing 15 changed files with 1,435 additions and 46 deletions.
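The headline change: Flax GPT2's attention blocks can now consume encoder hidden states through cross-attention. A minimal sketch of the resulting call path, assuming the post-commit FlaxGPT2Model signature; the dummy encoder output is illustrative:

import jax.numpy as jnp
from transformers import GPT2Config, FlaxGPT2Model

# add_cross_attention=True wires cross-attention sub-layers into each block
config = GPT2Config(add_cross_attention=True)
model = FlaxGPT2Model(config)  # randomly initialized, enough to exercise the path

input_ids = jnp.ones((1, 8), dtype="i4")
# Stand-in for a real encoder's output; width must match GPT2's hidden size (768)
encoder_hidden_states = jnp.zeros((1, 16, 768))

outputs = model(input_ids, encoder_hidden_states=encoder_hidden_states)
print(outputs.last_hidden_state.shape)  # (1, 8, 768)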
2 changes: 1 addition & 1 deletion docs/source/index.rst
@@ -356,7 +356,7 @@ Flax), PyTorch, and/or TensorFlow.
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| ELECTRA                     | ✅             | ✅             | ✅              | ✅                 | ✅           |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
-| Encoder decoder             | ❌             | ❌             | ✅              | ❌                 | ❌           |
+| Encoder decoder             | ❌             | ❌             | ✅              | ❌                 | ✅           |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| FairSeq Machine-Translation | ✅             | ❌             | ✅              | ❌                 | ❌           |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
7 changes: 7 additions & 0 deletions docs/source/model_doc/encoderdecoder.rst
@@ -40,3 +40,10 @@ EncoderDecoderModel

.. autoclass:: transformers.EncoderDecoderModel
    :members: forward, from_encoder_decoder_pretrained


FlaxEncoderDecoderModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.FlaxEncoderDecoderModel
    :members: __call__, from_encoder_decoder_pretrained
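For orientation, a short usage sketch of the two documented entry points, using the bert2gpt2 pairing that the commit's tests and examples settled on (checkpoint names are the standard hub ones):

from transformers import BertTokenizer, FlaxEncoderDecoderModel

# Warm-start: BERT weights for the encoder, GPT2 for the decoder; GPT2's
# newly added cross-attention weights are freshly initialized.
model = FlaxEncoderDecoderModel.from_encoder_decoder_pretrained("bert-base-cased", "gpt2")

tokenizer = BertTokenizer.from_pretrained("bert-base-cased")
inputs = tokenizer("My friends are cool but they eat too many carbs.", return_tensors="np")

# A single forward pass; the decoder_input_ids here are just placeholders.
outputs = model(
    input_ids=inputs.input_ids,
    attention_mask=inputs.attention_mask,
    decoder_input_ids=inputs.input_ids,
)
print(outputs.logits.shape)  # (batch, decoder_seq_len, gpt2_vocab_size)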
2 changes: 2 additions & 0 deletions src/transformers/__init__.py
@@ -1703,6 +1703,7 @@
"FlaxElectraPreTrainedModel",
]
)
_import_structure["models.encoder_decoder"].append("FlaxEncoderDecoderModel")
_import_structure["models.gpt2"].extend(["FlaxGPT2LMHeadModel", "FlaxGPT2Model", "FlaxGPT2PreTrainedModel"])
_import_structure["models.gpt_neo"].extend(
["FlaxGPTNeoForCausalLM", "FlaxGPTNeoModel", "FlaxGPTNeoPreTrainedModel"]
@@ -3171,6 +3172,7 @@
            FlaxElectraModel,
            FlaxElectraPreTrainedModel,
        )
+        from .models.encoder_decoder import FlaxEncoderDecoderModel
        from .models.gpt2 import FlaxGPT2LMHeadModel, FlaxGPT2Model, FlaxGPT2PreTrainedModel
        from .models.gpt_neo import FlaxGPTNeoForCausalLM, FlaxGPTNeoModel, FlaxGPTNeoPreTrainedModel
        from .models.marian import FlaxMarianModel, FlaxMarianMTModel, FlaxMarianPreTrainedModel
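With both the lazy _import_structure entry and the TYPE_CHECKING import in place, the class is importable from the top-level namespace (flax must be installed); a quick check:

from transformers import FlaxEncoderDecoderModel

# The lazy module resolves the attribute to the real submodule on first access
print(FlaxEncoderDecoderModel.__module__)
# transformers.models.encoder_decoder.modeling_flax_encoder_decoder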
1 change: 1 addition & 0 deletions src/transformers/models/auto/modeling_flax_auto.py
@@ -79,6 +79,7 @@
("t5", "FlaxT5ForConditionalGeneration"),
("mt5", "FlaxMT5ForConditionalGeneration"),
("marian", "FlaxMarianMTModel"),
("encoder-decoder", "FlaxEncoderDecoderModel"),
]
)
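Registering under the "encoder-decoder" key lets the Flax seq2seq auto class dispatch on a checkpoint's model_type; a sketch, where the local checkpoint path is illustrative:

from transformers import FlaxAutoModelForSeq2SeqLM

# Loads any checkpoint whose config has model_type == "encoder-decoder",
# e.g. a FlaxEncoderDecoderModel saved earlier with save_pretrained().
model = FlaxAutoModelForSeq2SeqLM.from_pretrained("./my-bert2gpt2-checkpoint")
print(type(model).__name__)  # FlaxEncoderDecoderModel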

12 changes: 10 additions & 2 deletions src/transformers/models/bert/modeling_flax_bert.py
@@ -625,13 +625,21 @@ def __call__(
        self,
        input_ids,
        attention_mask,
-        token_type_ids,
-        position_ids,
+        token_type_ids: Optional[np.ndarray] = None,
+        position_ids: Optional[np.ndarray] = None,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
+        # make sure `token_type_ids` is correctly initialized when not passed
+        if token_type_ids is None:
+            token_type_ids = jnp.zeros_like(input_ids)
+
+        # make sure `position_ids` is correctly initialized when not passed
+        if position_ids is None:
+            position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)
+
        hidden_states = self.embeddings(
            input_ids, token_type_ids, position_ids, attention_mask, deterministic=deterministic
        )
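Both fallbacks are pure jnp expressions and can be verified standalone; a minimal sketch:

import jax.numpy as jnp

input_ids = jnp.array([[5, 6, 7, 8], [9, 10, 11, 12]])  # (batch=2, seq_len=4)

token_type_ids = jnp.zeros_like(input_ids)  # everything falls in segment 0
position_ids = jnp.broadcast_to(
    jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape
)
print(position_ids)
# [[0 1 2 3]
#  [0 1 2 3]]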
7 changes: 6 additions & 1 deletion src/transformers/models/encoder_decoder/__init__.py
@@ -18,7 +18,7 @@

from typing import TYPE_CHECKING

-from ...file_utils import _LazyModule, is_torch_available
+from ...file_utils import _LazyModule, is_flax_available, is_torch_available


_import_structure = {
@@ -28,13 +28,18 @@
if is_torch_available():
    _import_structure["modeling_encoder_decoder"] = ["EncoderDecoderModel"]

+if is_flax_available():
+    _import_structure["modeling_flax_encoder_decoder"] = ["FlaxEncoderDecoderModel"]
+
if TYPE_CHECKING:
    from .configuration_encoder_decoder import EncoderDecoderConfig

    if is_torch_available():
        from .modeling_encoder_decoder import EncoderDecoderModel

+    if is_flax_available():
+        from .modeling_flax_encoder_decoder import FlaxEncoderDecoderModel
+
else:
    import sys

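The net effect of the guard: torch-only installs never touch the Flax module. A condensed, generic sketch of the pattern (not this file's exact code):

from transformers.file_utils import is_flax_available

if is_flax_available():
    from transformers import FlaxEncoderDecoderModel
    print("Flax backend available:", FlaxEncoderDecoderModel.__name__)
else:
    # Without flax installed, the symbol is simply never exposed.
    print("flax not installed; FlaxEncoderDecoderModel is unavailable")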

Large diffs are not rendered by default; the remaining files, including the new modeling_flax_encoder_decoder.py and its tests, are collapsed in this view.
