[Flax Encoder Decoder] Make Flax GPT2 working with cross attention #13008
Conversation
This is a great PR @ydshieh! Thanks a lot for working on this! :-) The PR looks great - that's exactly how I would have implemented it as well. It would be great if you could remove the encoder->decoder projection layer in a first PR to make it consistent with PyTorch. Also, it's important that we test newly added features (such as GPT2's cross attention layer), so I think we'll have to add a test for it as well.
@patrickvonplaten Thanks for the feedback, I will remove the encoder->decoder projection layer. Yes, I would like to work on that as well.
Exactly, let's include it in this PR so that we can then also add a first test for it with GPT2, like the one we have for PyTorch.
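(For reference, a rough sketch of what such a Bert2GPT2 test could look like - the checkpoint names are hypothetical tiny test models, and the actual test added later in this PR may differ:)

```python
import numpy as np
from transformers import FlaxEncoderDecoderModel

# Hypothetical tiny checkpoints, just for illustration.
model = FlaxEncoderDecoderModel.from_encoder_decoder_pretrained(
    "hf-internal-testing/tiny-bert", "hf-internal-testing/tiny-random-gpt2"
)

input_ids = np.array([[1, 2, 3, 4, 5]])
decoder_input_ids = np.array([[1, 2, 3]])
outputs = model(
    input_ids=input_ids,
    decoder_input_ids=decoder_input_ids,
    output_attentions=True,
)
# One cross-attention tensor per decoder layer, each of shape
# (batch, num_heads, decoder_seq_len, encoder_seq_len).
assert outputs.cross_attentions[0].shape[-2:] == (3, 5)
```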
@ydshieh This is great! Do let me know if I can help in any way.
Here is my first attempt. There are 5 other comments starting with "# Q: " - it would be great if you could also give some feedback on them, but they are less important.
Hey @ydshieh, the PR already seems to be in great shape - awesome job! Thanks a lot for being so clear about your questions - I answered most of them in the comments above.
Note that this PR won't (yet) enable generic ImageToText but just generic TextToText with GPT2 as the decoder. In a follow-up PR we will then define a new class for that. Please let me know if anything is unclear! I'm more than happy to also take a deeper look if you're stuck somewhere :-)
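(As a quick illustration of the TextToText setup this enables - a minimal sketch assuming the FlaxEncoderDecoderModel API this PR targets; the checkpoint names are just examples:)

```python
from transformers import BertTokenizerFast, FlaxEncoderDecoderModel

# Any pretrained encoder + GPT2 as the decoder. GPT2's cross-attention
# weights are freshly initialized, so the combined model still needs
# fine-tuning before it generates anything sensible.
model = FlaxEncoderDecoderModel.from_encoder_decoder_pretrained("bert-base-cased", "gpt2")
tokenizer = BertTokenizerFast.from_pretrained("bert-base-cased")

inputs = tokenizer("A sunny day at the beach.", return_tensors="np")
outputs = model(
    input_ids=inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
    decoder_input_ids=inputs["input_ids"],  # placeholder ids, shapes only
)
print(outputs.logits.shape)  # (batch, decoder_seq_len, gpt2_vocab_size)
```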
Hey @patrickvonplaten, thanks for all the feedback. I will continue the remaining work, including the test script as you mentioned. Since we decided not to consider ImageToText in this PR, I removed the encoder->decoder projection layer, and it at least can run.
Hi @patrickvonplaten, I have made FlaxEncoderDecoder available in the library. It remains to add the test file :)
Hey @ydshieh, this PR is already in very good shape. I'm very impressed by how well you implemented this class! I've added a Bert2GPT2 test that verifies that your PR works correctly (it does - great job ;-)). I think the only thing left to do now is to change the remaining examples and tests to bert2gpt2.
Hi @patrickvonplaten, I changed all remaining examples & tests to bert2gpt2, and renamed things accordingly. Other than this, I think the task is done :)
```python
            rngs=rngs,
        )

    def prepare_inputs_for_generation(
```
@patil-suraj In PyTorch we actually shift the responsibility of creating the model inputs to the decoder class, e.g. we call:
self.decoder.prepare_inputs_for_generation(...)
In Flax this is more difficult, since we can't do self.decoder.prepare_inputs_for_generation as it's a module and not a model. I'm very much fine with the current solution as it covers all the relevant cases IMO (position_ids & attention_mask).
So I'm happy to go forward with it - let me know if you disagree here.
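(To make the discussion concrete, here is a standalone sketch of the mask/position-ids handling pattern used by Flax generative models - an illustration of the idea, not the exact code in this PR, and the argument names are assumptions:)

```python
import jax.numpy as jnp
from jax import lax

def prepare_inputs_for_generation(decoder_input_ids, max_length, decoder_attention_mask=None):
    batch_size, seq_length = decoder_input_ids.shape
    # Beyond the prompt, tokens are written into the cache one step at a time,
    # so the extended mask is all ones up front; the module's causal mask
    # takes care of positions that haven't been generated yet.
    extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
    if decoder_attention_mask is not None:
        # Positions must count only non-padding tokens.
        position_ids = decoder_attention_mask.cumsum(axis=-1) - 1
        extended_attention_mask = lax.dynamic_update_slice(
            extended_attention_mask, decoder_attention_mask.astype("i4"), (0, 0)
        )
    else:
        position_ids = jnp.broadcast_to(
            jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length)
        )
    return {
        "decoder_attention_mask": extended_attention_mask,
        "decoder_position_ids": position_ids,
    }
```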
Hey @ydshieh,
Amazing work! I have some final nits and I'd like to have one review from @patil-suraj, but apart from this, I think the PR is ready to be merged :-)
Once @patil-suraj gives his OK, I'll run all the slow tests for JAX GPT2 and the JAX Encoder Decoder model once, and we can merge after :-)
@ydshieh Amazing job on adding the Flax encoder decoder class! This lays the foundation for the image-to-text use case we discussed. I'm currently working on adding a follow-up to it.
@patrickvonplaten Sure, I would like to continue with it. Actually, I just finished a first version; I need to finalize it, and will request a review (maybe from someone else? not sure if you work with TF). Here is a preview.
What does this PR do?
The current Flax GPT2 doesn't support cross attention, while PyTorch's GPT2 does. This PR adds cross attention to Flax GPT2, closely following the code in PyTorch's GPT2 and Flax's BART models.
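(To illustrate what "cross attention" adds here - a minimal, self-contained sketch of the mechanism itself, not the PR's actual implementation: queries come from the decoder's hidden states, while keys and values come from the encoder output.)

```python
import jax
import jax.numpy as jnp

def cross_attention(decoder_states, encoder_states, w_q, w_k, w_v):
    # Single-head scaled dot-product attention: the decoder queries the encoder.
    q = decoder_states @ w_q   # (batch, tgt_len, head_dim)
    k = encoder_states @ w_k   # (batch, src_len, head_dim)
    v = encoder_states @ w_v   # (batch, src_len, head_dim)
    scores = jnp.einsum("bqd,bkd->bqk", q, k) / jnp.sqrt(q.shape[-1])
    weights = jax.nn.softmax(scores, axis=-1)  # attend over encoder positions
    return jnp.einsum("bqk,bkd->bqd", weights, v)

decoder_states = jax.random.normal(jax.random.PRNGKey(0), (2, 3, 8))
encoder_states = jax.random.normal(jax.random.PRNGKey(1), (2, 5, 8))
w = jnp.eye(8)
print(cross_attention(decoder_states, encoder_states, w, w, w).shape)  # (2, 3, 8)
```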
However, I added one more thing: a projection from the encoder's last hidden state to the dimension of the decoder's hidden states. I think this is useful when we want to combine GPT2 with different pretrained encoders (in particular, image encoders like ViT or CLIPVision).
If HuggingFace thinks it is better not to include this (so it would be more identical to PyTorch's version), I will remove it.
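(A minimal Flax sketch of that projection idea - the module and its name are hypothetical, just to illustrate:)

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

class EncToDecProjection(nn.Module):
    # Hypothetical module: project encoder states to the decoder's hidden size,
    # so e.g. a ViT encoder can feed a GPT2 decoder of a different width.
    decoder_hidden_size: int

    @nn.compact
    def __call__(self, encoder_hidden_states):
        if encoder_hidden_states.shape[-1] != self.decoder_hidden_size:
            encoder_hidden_states = nn.Dense(self.decoder_hidden_size)(encoder_hidden_states)
        return encoder_hidden_states

proj = EncToDecProjection(decoder_hidden_size=1024)
x = jnp.ones((1, 197, 768))  # e.g. ViT-Base patch features
params = proj.init(jax.random.PRNGKey(0), x)
print(proj.apply(params, x).shape)  # (1, 197, 1024)
```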
Finally, is there any documentation file I need to edit for this change? If so, could you point me to the file(s), please?
Before submitting
- Did you read the contributor guideline, Pull Request section?
- Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
- Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
Who can review?
@patrickvonplaten
@patil-suraj