
[Flax Encoder Decoder] Make Flax GPT2 working with cross attention #13008

Merged
merged 35 commits on Aug 23, 2021

Commits — 35 commits (changes from all commits)
d45ea91
make flax gpt2 working with cross attention
Aug 4, 2021
d8069c7
Remove encoder->decoder projection layer
ydshieh Aug 6, 2021
2c68bc4
A draft (incomplete) for FlaxEncoderDecoderModel
ydshieh Aug 9, 2021
4d3c963
Add the method from_encoder_decoder_pretrained + the docstrings
ydshieh Aug 10, 2021
b71e97e
Fix the mistakes of using EncoderDecoderModel
ydshieh Aug 10, 2021
02b52fc
Fix style
ydshieh Aug 10, 2021
d18835a
Add FlaxEncoderDecoderModel to the library
ydshieh Aug 13, 2021
281fa2a
Fix cyclic imports
ydshieh Aug 13, 2021
b6f5f4f
Add FlaxEncoderDecoderModel to modeling_flax_auto.py
ydshieh Aug 13, 2021
a36d7c3
Remove question comments
ydshieh Aug 13, 2021
8813d90
add tests for FlaxEncoderDecoderModel
ydshieh Aug 14, 2021
b253fee
add flax_encoder_decoder to the lists of ignored entries in check_rep…
ydshieh Aug 14, 2021
d0fe4b3
fix missing required positional arguments
ydshieh Aug 14, 2021
e79d5d7
Remove **kwargs when creating FlaxEncoderDecoderModel in from_encoder…
ydshieh Aug 14, 2021
27e1224
Fix: Use sequences from the generated_output
ydshieh Aug 14, 2021
a3e4122
Change a check from assert to raise ValueError
ydshieh Aug 14, 2021
3850997
Fix examples and token ids issues
ydshieh Aug 14, 2021
10de83e
Fix missing all_cross_attentions when outputting tuple in modeling_gpt2
ydshieh Aug 15, 2021
70dd479
Remove the changes in configuration docstrings.
ydshieh Aug 16, 2021
82ac950
allow for bert 2 gpt2
patrickvonplaten Aug 16, 2021
f6906f0
make fix-copies
patrickvonplaten Aug 16, 2021
a416698
Merge branch 'flax_gpt2_with_cross_attn' of https://github.com/ydshie…
ydshieh Aug 16, 2021
62a9bd7
Apply suggestions from code review
ydshieh Aug 16, 2021
c94754a
Merge branch 'flax_gpt2_with_cross_attn' of https://github.com/ydshie…
ydshieh Aug 16, 2021
5b4849f
Change remaining examples to bert2gpt2
ydshieh Aug 16, 2021
cf00378
Change the test to Bert2GPT2
ydshieh Aug 16, 2021
e1d5739
Fix examples
ydshieh Aug 16, 2021
7410649
Fix import
ydshieh Aug 16, 2021
7f29ee5
Fix unpack bug
ydshieh Aug 16, 2021
fa0dde9
Rename to FlaxEncoderDecoderModelTest and change the test to bert2gpt2
ydshieh Aug 17, 2021
109b5d1
Apply suggestions from code review
ydshieh Aug 17, 2021
369b3aa
Fix: NotImplentedError -> NotImplementedError
ydshieh Aug 17, 2021
d617f0e
Apply suggestions from code review
ydshieh Aug 18, 2021
0adab76
up
patrickvonplaten Aug 23, 2021
cafc9db
finalize
patrickvonplaten Aug 23, 2021
2 changes: 1 addition & 1 deletion docs/source/index.rst
@@ -353,7 +353,7 @@ Flax), PyTorch, and/or TensorFlow.
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| ELECTRA | ✅ | ✅ | ✅ | ✅ | ✅ |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
- | Encoder decoder | ❌ | ❌ | ✅ | ❌ | ❌ |
+ | Encoder decoder | ❌ | ❌ | ✅ | ❌ | ✅ |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| FairSeq Machine-Translation | ✅ | ❌ | ✅ | ❌ | ❌ |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
7 changes: 7 additions & 0 deletions docs/source/model_doc/encoderdecoder.rst
@@ -40,3 +40,10 @@ EncoderDecoderModel

.. autoclass:: transformers.EncoderDecoderModel
:members: forward, from_encoder_decoder_pretrained


FlaxEncoderDecoderModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.FlaxEncoderDecoderModel
:members: __call__, from_encoder_decoder_pretrained
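
For context, a minimal usage sketch of the FlaxEncoderDecoderModel API documented above (checkpoint names and inputs are illustrative, not taken from this PR; the bert2gpt2 pairing mirrors the tests added here):

import numpy as np
from transformers import BertTokenizerFast, FlaxEncoderDecoderModel

# pair a BERT encoder with a GPT2 decoder; this PR adds the cross-attention the decoder needs
model = FlaxEncoderDecoderModel.from_encoder_decoder_pretrained("bert-base-cased", "gpt2")
tokenizer = BertTokenizerFast.from_pretrained("bert-base-cased")

inputs = tokenizer("The quick brown fox jumps over the lazy dog.", return_tensors="np")

# the Flax model expects decoder_input_ids explicitly; start from the decoder's BOS token
decoder_input_ids = np.array([[model.config.decoder.bos_token_id]])

outputs = model(
    input_ids=inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
    decoder_input_ids=decoder_input_ids,
)
logits = outputs.logits  # (batch, decoder_seq_len, decoder_vocab_size)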
2 changes: 2 additions & 0 deletions src/transformers/__init__.py
@@ -1692,6 +1692,7 @@
"FlaxElectraPreTrainedModel",
]
)
_import_structure["models.encoder_decoder"].append("FlaxEncoderDecoderModel")
_import_structure["models.gpt2"].extend(["FlaxGPT2LMHeadModel", "FlaxGPT2Model", "FlaxGPT2PreTrainedModel"])
_import_structure["models.gpt_neo"].extend(
["FlaxGPTNeoForCausalLM", "FlaxGPTNeoModel", "FlaxGPTNeoPreTrainedModel"]
@@ -3151,6 +3152,7 @@
FlaxElectraModel,
FlaxElectraPreTrainedModel,
)
from .models.encoder_decoder import FlaxEncoderDecoderModel
from .models.gpt2 import FlaxGPT2LMHeadModel, FlaxGPT2Model, FlaxGPT2PreTrainedModel
from .models.gpt_neo import FlaxGPTNeoForCausalLM, FlaxGPTNeoModel, FlaxGPTNeoPreTrainedModel
from .models.marian import FlaxMarianModel, FlaxMarianMTModel, FlaxMarianPreTrainedModel
1 change: 1 addition & 0 deletions src/transformers/models/auto/modeling_flax_auto.py
@@ -79,6 +79,7 @@
("t5", "FlaxT5ForConditionalGeneration"),
("mt5", "FlaxMT5ForConditionalGeneration"),
("marian", "FlaxMarianMTModel"),
("encoder-decoder", "FlaxEncoderDecoderModel"),
]
)

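If this table is the Flax sequence-to-sequence mapping behind FlaxAutoModelForSeq2SeqLM (an assumption based on the neighbouring t5/mt5/marian entries, not stated in the diff), a saved encoder-decoder checkpoint could then be reloaded through the auto class:

from transformers import FlaxAutoModelForSeq2SeqLM

# placeholder path: a directory written by FlaxEncoderDecoderModel.save_pretrained(...)
model = FlaxAutoModelForSeq2SeqLM.from_pretrained("./my-flax-bert2gpt2-checkpoint")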
12 changes: 10 additions & 2 deletions src/transformers/models/bert/modeling_flax_bert.py
@@ -625,13 +625,21 @@ def __call__(
        self,
        input_ids,
        attention_mask,
-       token_type_ids,
-       position_ids,
+       token_type_ids: Optional[np.ndarray] = None,
+       position_ids: Optional[np.ndarray] = None,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
+       # make sure `token_type_ids` is correctly initialized when not passed
+       if token_type_ids is None:
+           token_type_ids = jnp.zeros_like(input_ids)
+
+       # make sure `position_ids` is correctly initialized when not passed
+       if position_ids is None:
+           position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)
+
        hidden_states = self.embeddings(
            input_ids, token_type_ids, position_ids, attention_mask, deterministic=deterministic
        )
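The two defaults added above are easy to check in isolation; a small sketch of what they produce (input values are only an example):

import jax.numpy as jnp

input_ids = jnp.array([[101, 7592, 2088, 102],
                       [101, 2023, 2003, 102]])    # shape (batch=2, seq_len=4)

# token_type_ids default: all zeros, same shape as input_ids
token_type_ids = jnp.zeros_like(input_ids)          # [[0, 0, 0, 0], [0, 0, 0, 0]]

# position_ids default: 0 .. seq_len-1, broadcast across the batch dimension
position_ids = jnp.broadcast_to(
    jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape
)                                                    # [[0, 1, 2, 3], [0, 1, 2, 3]]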
7 changes: 6 additions & 1 deletion src/transformers/models/encoder_decoder/__init__.py
@@ -18,7 +18,7 @@

from typing import TYPE_CHECKING

-from ...file_utils import _LazyModule, is_torch_available
+from ...file_utils import _LazyModule, is_flax_available, is_torch_available


_import_structure = {
@@ -28,13 +28,18 @@
if is_torch_available():
_import_structure["modeling_encoder_decoder"] = ["EncoderDecoderModel"]

if is_flax_available():
_import_structure["modeling_flax_encoder_decoder"] = ["FlaxEncoderDecoderModel"]

if TYPE_CHECKING:
from .configuration_encoder_decoder import EncoderDecoderConfig

if is_torch_available():
from .modeling_encoder_decoder import EncoderDecoderModel

if is_flax_available():
from .modeling_flax_encoder_decoder import FlaxEncoderDecoderModel

else:
import sys

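The is_flax_available() guard above means the new class is only importable when a JAX/Flax backend is installed; a short sketch of how calling code can mirror that check (is_flax_available comes from file_utils, as in the diff):

from transformers.file_utils import is_flax_available

if is_flax_available():
    from transformers import FlaxEncoderDecoderModel
    print("Flax backend found:", FlaxEncoderDecoderModel.__name__)
else:
    print("JAX/Flax not installed; FlaxEncoderDecoderModel is unavailable.")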

Large diffs are not rendered by default.