
[Flax Encoder Decoder] Make Flax GPT2 working with cross attention #13008

Merged
merged 35 commits into huggingface:master from ydshieh:flax_gpt2_with_cross_attn on Aug 23, 2021

Conversation

ydshieh
Collaborator

@ydshieh ydshieh commented Aug 5, 2021

What does this PR do?

The current Flax GPT2 doesn't support cross attention, while PyTorch's GPT2 does. This PR adds cross attention to Flax GPT2, closely following the code in PyTorch's GPT2 and Flax's Bart models.

However, I added one more thing: a projection from the encoder's last hidden state to the hidden dimension of the decoder. I think this is useful when we want to combine GPT2 with different pretrained encoders (in particular, image encoders like ViT or CLIPVision).

project_encoder = getattr(self.config, "project_encoder", None)
if project_encoder:
    encoder_hidden_states = self.encoder_projection_ln(encoder_hidden_states)
    feed_forward_hidden_states = self.encoder_projection_mlp(
        encoder_hidden_states, deterministic=deterministic
    )
    # residual connection
    encoder_hidden_states = feed_forward_hidden_states
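
For reference, here is a minimal sketch (not the exact code in this PR) of a sub-module that could implement the projection referenced above: a LayerNorm followed by a small MLP that maps the encoder's hidden size to the decoder's hidden size. The class and parameter names below are assumptions for illustration only.

import flax.linen as nn
import jax.numpy as jnp

# Hypothetical sketch of the encoder->decoder projection; names are illustrative.
class FlaxEncoderProjection(nn.Module):
    decoder_hidden_size: int
    dtype: jnp.dtype = jnp.float32

    @nn.compact
    def __call__(self, encoder_hidden_states, deterministic: bool = True):
        # normalize, then project to the decoder's hidden size via a 2-layer MLP
        hidden = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)(encoder_hidden_states)
        hidden = nn.Dense(4 * self.decoder_hidden_size, dtype=self.dtype)(hidden)
        hidden = nn.gelu(hidden)
        return nn.Dense(self.decoder_hidden_size, dtype=self.dtype)(hidden)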

If HuggingFace thinks it is better not to include this (so the implementation stays identical to PyTorch's version), I will remove it.

Finally, is there any documentation file I need to edit for this change? If so, could you point me to the file(s), please?

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline, Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

@patrickvonplaten
@patil-suraj

@ydshieh ydshieh changed the title [WIP] Make Flax GPT2 working with cross attention Make Flax GPT2 working with cross attention Aug 5, 2021
@ydshieh ydshieh marked this pull request as ready for review August 5, 2021 07:24
@patrickvonplaten
Contributor

patrickvonplaten commented Aug 6, 2021

This is a great PR @ydshieh! Thanks a lot for working on this! :-) The PR looks great - that's exactly how I would have implemented it as well.

It would be great if you could remove the encoder->decoder projection layer in this first PR to make it consistent with PyTorch. Also, we will probably have to add a FlaxEncoderDecoder model architecture file in addition to this, to showcase how GPT2 can be used with cross attention and to test this new feature.

The FlaxEncoderDecoder model should look very similar to the PyTorch implementation: https://github.com/huggingface/transformers/blob/master/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py . We can also get some inspiration from https://github.com/gchhablani/multilingual-image-captioning/blob/main/models/flax_clip_vision_mbart/modeling_clip_vision_mbart.py . We'll have to make it more general, though. cc @gchhablani @bhadreshpsavani

=> It's important that we test newly added features (such as GPT2's cross attention layer), so I think we'll have to add modeling_flax_encoder_decoder.py right away. This will definitely require some more work. If you are interested in giving it a shot, @ydshieh, that would be great - otherwise I can also continue this PR next week :-)

@ydshieh
Collaborator Author

ydshieh commented Aug 6, 2021

@patrickvonplaten Thanks for the feedback, I will remove the encoder->decoder projection layer.

Yes, I would like to work on FlaxEncoderDecoder; it is a great learning opportunity. If I understand correctly, you would prefer FlaxEncoderDecoder to be included in this PR rather than in a separate PR, right?

@patrickvonplaten
Contributor

Exactly, let's include it in this PR so that we can then also add a first test for it with GPT2, like this one for PyTorch:

class GPT2EncoderDecoderModelTest(EncoderDecoderMixin, unittest.TestCase):

@gchhablani
Contributor

gchhablani commented Aug 7, 2021

@ydshieh This is great! Do let me know if I can help in any way.

@ydshieh ydshieh marked this pull request as draft August 9, 2021 14:32
@ydshieh ydshieh changed the title Make Flax GPT2 working with cross attention [WIP] Make Flax GPT2 working with cross attention Aug 9, 2021
@ydshieh
Collaborator Author

ydshieh commented Aug 10, 2021

Hi, @patrickvonplaten

Here is my first attempt at FlaxEncoderDecoderModel. However, I have 3 main questions - when you have time, could you give some suggestions on them, please?

  1. The __call__/encode/decode methods in Flax models (and modules) don't seem to have **kwargs, at least not in the FlaxBartModel code.

    The current version of FlaxEncoderDecoderModel doesn't have a token_type_ids parameter, and might have problems when the decoder module is FlaxBertModule, because it requires a token_type_ids argument.

    Do you have a better idea for dealing with the token_type_ids parameter?

    - Try to add it explicitly in the methods' parameters, like `position_ids`?
    - Or is there a good way to use `**kwargs` in this case?
    
  2. In self.__call__(), when decoder_input_ids is None, we use shift_tokens_right() and it requires decoder_start_token_id (a sketch of this helper is at the end of this comment).

    However, self.config (EncoderDecoderConfig), or even self.config.decoder, might not have decoder_start_token_id defined.

    - Should we try to add `decoder_start_token_id` in `self.from_encoder_decoder_pretrained()`, using logic similar to `generation_utils._get_decoder_start_token_id()`?
    - Or should we just leave it to the users to specify it (when it is not already in the config)?
    
  3. In modeling_encoder_decoder.EncoderDecoderModel.prepare_inputs_for_generation(), we use the decoder model's
    prepare_inputs_for_generation():

       decoder_inputs = self.decoder.prepare_inputs_for_generation(decoder_input_ids, ...)

    However, in Flax's version, we only have the decoder module, not the decoder model. Is the current `FlaxEncoderDecoderModel.prepare_inputs_for_generation()` implementation OK?

There are 5 other comments starting with "# Q: ". It would be great if you could also give some feedback on them, but they are less important.
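
For reference, here is a minimal sketch of the shift_tokens_right() helper that question 2 refers to, modeled on the Flax Bart implementation (the exact code in this PR may differ). It shows why decoder_start_token_id must be available:

import numpy as np

def shift_tokens_right(input_ids: np.ndarray, pad_token_id: int, decoder_start_token_id: int) -> np.ndarray:
    # shift the tokens one position to the right to build decoder_input_ids
    shifted_input_ids = np.zeros_like(input_ids)
    shifted_input_ids[:, 1:] = input_ids[:, :-1]
    shifted_input_ids[:, 0] = decoder_start_token_id
    # replace possible -100 label-padding values with the pad token id
    shifted_input_ids = np.where(shifted_input_ids == -100, pad_token_id, shifted_input_ids)
    return shifted_input_ids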

@patrickvonplaten
Contributor

Hey @ydshieh,

The PR already seems to be in a great shape - awesome job! Thanks a lot for being so clear about your questions - I answered most of them in the comments above.

In short:

  • let's just remove token_type_ids for FlaxEncoderDecoder for now
  • decoder_input_ids should never be generated from input_ids here; the user should be forced to pass them
  • we should define a `decode` function and `prepare_inputs_for_generation` similar to how it's done for `FlaxBart`
  • The goal of this PR should really be to enable tests like this one:
    class GPT2EncoderDecoderModelTest(EncoderDecoderMixin, unittest.TestCase):

Note that this PR won't (yet) enable generic ImageToText, but just generic TextToText with GPT2 as the decoder. In a follow-up PR we will then define a new FlaxImageEncoderDecoder class specifically for ImageToText. However, in my opinion it's a much better approach to start with TextToText (as you're doing here), where we can more or less translate most of the code from PyTorch.

Please let me know if anything is unclear! I'm more than happy to also take a deeper look if you're stuck somewhere :-)

@ydshieh
Collaborator Author

ydshieh commented Aug 13, 2021

Hey @patrickvonplaten Thanks for all the feedback. I will continue the remaining work, including the test script as you mentioned.

Since we decided not to consider token_type_ids for now, I will need to change the example in the model file from bert2gpt2 = ... to gpt2togpt2 = ...; otherwise the example won't run (it can't even initialize the model). I tested locally with

FlaxEncoderDecoderModel.from_encoder_decoder_pretrained('gpt2', 'gpt2') 

and it can at least run __call__ (a rough usage sketch follows below). Unless you have other ideas for an example pair, I am going with it :)
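
For illustration, a rough sketch of how the gpt2-to-gpt2 pairing could be exercised (the token choices and exact example text here are assumptions, not the final doc example):

import jax.numpy as jnp
from transformers import FlaxEncoderDecoderModel, GPT2Tokenizer

model = FlaxEncoderDecoderModel.from_encoder_decoder_pretrained("gpt2", "gpt2")
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

inputs = tokenizer("My friends are cool but they eat too many carbs.", return_tensors="np")
# decoder_input_ids must be passed explicitly; here we simply start from the BOS token
decoder_input_ids = jnp.array([[model.config.decoder.bos_token_id]])

outputs = model(input_ids=inputs["input_ids"], decoder_input_ids=decoder_input_ids)
logits = outputs.logits  # shape: (batch, decoder_seq_len, vocab_size)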

@ydshieh
Collaborator Author

ydshieh commented Aug 13, 2021

Hi @patrickvonplaten , I have made FlaxEncoderDecoder available in the library. What remains is to add the test file :)

@patrickvonplaten
Contributor

patrickvonplaten commented Aug 16, 2021

Hey @ydshieh,

This PR is already in a very good shape. I'm very impressed by how well you implemented this class! The EncoderDecoderModel is one of the most difficult classes to implement.

I've added a Bert2GPT2 test that verifies that your PR works correctly (it does - great job ;-)). I think the only thing left to do now is to change the examples and tests from "GPT2toGPT2" to "BERT2GPT2" and then we can merge this one :-)

@ydshieh ydshieh changed the title [WIP] Make Flax GPT2 working with cross attention Make Flax GPT2 working with cross attention Aug 16, 2021
@ydshieh
Collaborator Author

ydshieh commented Aug 17, 2021

Hi @patrickvonplaten , I changed all remaining examples & tests to bert2gpt2, and renamed EncoderDecoderModelTest to FlaxEncoderDecoderModelTest. The only remark: FlaxEncoderDecoderModel doesn't handle position_ids and token_type_ids, because those depend on the individual encoder/decoder models (modules, actually), and it seems to me we don't pass **kwargs to module.apply. (It would be great if you could comment on this - I am not sure, it's just my observation.)

Other than this, I think the task is done :)

rngs=rngs,
)

def prepare_inputs_for_generation(
Contributor


@patil-suraj In PyTorch we actually shift the responsibility of creating the model inputs to the decoder class, e.g. we call:

self.decoder.prepare_inputs_for_generation(...)

In Flax this is more difficult since we can't do self.decoder.prepare_inputs_for_generation, as it's a module and not a model. I'm very much fine with the current solution as it covers all the relevant cases IMO (position_ids & attention_mask).

So I'm happy to go forward with it - let me know if you disagree here
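
For context, here is a rough sketch of the Flax-side approach being discussed, written as a method of FlaxEncoderDecoderModel and following the FlaxBart pattern; the details are illustrative rather than the exact code in this diff:

import jax.numpy as jnp
from jax import lax

def prepare_inputs_for_generation(
    self,
    decoder_input_ids,
    max_length,
    decoder_attention_mask=None,
    encoder_outputs=None,
    **kwargs,
):
    batch_size, seq_length = decoder_input_ids.shape
    # initialize the cache used for fast auto-regressive decoding
    past_key_values = self.init_cache(batch_size, max_length, encoder_outputs)

    # build an attention mask and position ids over the full max_length;
    # positions past the current step are handled by the cache index during generation
    extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
    if decoder_attention_mask is not None:
        decoder_position_ids = decoder_attention_mask.cumsum(axis=-1) - 1
        extended_attention_mask = lax.dynamic_update_slice(
            extended_attention_mask, decoder_attention_mask, (0, 0)
        )
    else:
        decoder_position_ids = jnp.broadcast_to(
            jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length)
        )

    return {
        "past_key_values": past_key_values,
        "encoder_outputs": encoder_outputs,
        "decoder_attention_mask": extended_attention_mask,
        "decoder_position_ids": decoder_position_ids,
    }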

Contributor

@patrickvonplaten patrickvonplaten left a comment


Hey @ydshieh,

Amazing work, I have some final nits and I'd like to have one review from @patil-suraj, but I think apart from this, the PR is ready to be merged :-)

Once @patil-suraj gives his OK, I'll run all slow tests for JAX GPT2 and JAX Encoder Decoder model once and we can merge after :-)

@patrickvonplaten patrickvonplaten merged commit 2e20c0f into huggingface:master Aug 23, 2021
@patrickvonplaten patrickvonplaten changed the title Make Flax GPT2 working with cross attention [Flax Encoder Decoder] Make Flax GPT2 working with cross attention Aug 23, 2021
@patrickvonplaten
Contributor

@ydshieh amazing job on adding the Flax encoder decoder class! This lays the foundation for the FlaxVisionEncoderDecoder framework :-)

I'm currently working on adding a SpeechEncoderDecoder model here: https://github.com/huggingface/transformers/blob/19106d1c5548b3083c1d5ced667de6854367f1e0/src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py - the FlaxVisionEncoderDecoder would be added in a similar spirit. If you are interested, we could try to add this class in a follow-up PR :-)

@ydshieh
Collaborator Author

ydshieh commented Aug 23, 2021

@patrickvonplaten Sure, I would like to continue with it. Actually, I just finished TFEncoderDecoderModel and added cross attention to some TF models (Bert/GPT2/Roberta/Electra). In particular, the test_bert2gpt2_summarization and test_bert2bert_summarization tests now work in the TF version (after some bug fixes in the library, though). I tested them locally with @slow disabled.

I need to finalize it, and will then request a review (maybe from someone else? I'm not sure if you work with TF).

I think the implementation of VisionEncoderDecoder will be straightforward, right? Basically, just change the inputs to pixel_values and probably add some feature extraction part.

Here is a preview for TFEncoderDecoderModel :)
#13222

@ydshieh ydshieh deleted the flax_gpt2_with_cross_attn branch May 5, 2022 10:39