
Adding the state-of-the-art contrastive search decoding methods for the codebase of generation_utils.py #19477

Merged
merged 39 commits into huggingface:main on Oct 19, 2022

Conversation

gmftbyGMFTBY
Contributor

@gmftbyGMFTBY gmftbyGMFTBY commented Oct 11, 2022

Adding the state-of-the-art contrastive search decoding method for the generation_utils codebase

Fixes #19182

In this PR, I add the source code of our proposed state-of-the-art decoding method for off-the-shelf neural text generation models. The main changes are in the following files: (1) src/transformers/generation_utils.py; (2) examples/pytorch/text-generation/run_generation_contrastive_search.py. To run the test script, please use these commands:

cd examples/pytorch/text-generation;
CUDA_VISIBLE_DEVICES=0 python run_generation_contrastive_search.py --model_type=gpt2 --model_name_or_path=gpt2-large

Before submitting

Who can review?

Following @gante's suggestion, @patrickvonplaten and @sgugger can review this PR.

@HuggingFaceDocBuilderDev

HuggingFaceDocBuilderDev commented Oct 11, 2022

The documentation is not available anymore as the PR was closed or merged.

@gante
Member

gante commented Oct 11, 2022

@sgugger @patrickvonplaten context: this is the implementation by the authors of this NeurIPS paper, as first proposed in #19182 -- a new generation strategy with very interesting results!
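For context, the token-selection rule from the paper can be sketched in plain Python. This is a simplified illustration only, not the PR's actual implementation: `alpha` corresponds to the `penalty_alpha` generation argument, candidate probabilities come from the top-k of the model's next-token distribution, and the degeneration penalty is the maximum cosine similarity between a candidate's hidden state and the hidden states of previously generated tokens.

```python
import math

def cosine(u, v):
    # cosine similarity between two vectors
    dot = sum(a * b for a, b in zip(u, v))
    nu = math.sqrt(sum(a * a for a in u))
    nv = math.sqrt(sum(a * a for a in v))
    return dot / (nu * nv)

def contrastive_score(prob, cand_hidden, prev_hiddens, alpha):
    # (1 - alpha) * model confidence - alpha * degeneration penalty
    penalty = max(cosine(cand_hidden, h) for h in prev_hiddens)
    return (1 - alpha) * prob - alpha * penalty

# toy example: two top-k candidates with probabilities and hidden states
prev = [[1.0, 0.0], [0.8, 0.6]]          # hidden states of the context so far
candidates = [
    (0.6, [1.0, 0.05]),  # more likely, but nearly collinear with a previous state
    (0.4, [0.0, 1.0]),   # less likely, but dissimilar to the context
]
alpha = 0.6
scores = [contrastive_score(p, h, prev, alpha) for p, h in candidates]
best = max(range(len(scores)), key=scores.__getitem__)
```

With a large enough `alpha`, the less likely but less repetitive candidate wins, which is exactly how the method counteracts degeneration.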

Collaborator

@sgugger sgugger left a comment

Thanks a lot for adding this new method. The added code in generate looks clean to me, I just left a couple of nits.

For the example, it would really be good if you could write a new one leveraging the auto-APIs instead of copying the old run_generation.py, which is severely outdated and only works for very few models.

src/transformers/generation_utils.py
@@ -3446,3 +3766,152 @@ def top_k_top_p_filtering(
)

return logits


# ========== utils for contrastive search decoding method ========= #
Collaborator

Those utils would probably be best in a submodule (like we have generation_beam_search or generation_beam_constraint) to avoid generation_utils being too big.

src/transformers/generation_utils.py
Comment on lines 104 to 110
# kwargs["language"] = tokenizer.lang2id[language]

# TODO fix mask_token_id setup when configurations will be synchronized between models and tokenizers
# XLM masked-language modeling (MLM) models need masked token
# is_xlm_mlm = "mlm" in args.model_name_or_path
# if is_xlm_mlm:
# kwargs["mask_token_id"] = tokenizer.mask_token_id
Collaborator

To clean up?

@gmftbyGMFTBY
Contributor Author

Hi, @sgugger, thank you so much for your suggestions. I will fix these problems quickly!

@gmftbyGMFTBY
Contributor Author

Hello @sgugger, is there any documentation or introduction for the auto-APIs?

@sgugger
Collaborator

sgugger commented Oct 12, 2022

The documentation would be the place to start. You can also look at all other examples!

@gmftbyGMFTBY
Contributor Author

Hello @sgugger, I have fixed the problems based on your valuable suggestions! Besides, I have updated the test script to use the auto-APIs for inference.

The command line to run this test script can be found in its docstring.

Collaborator

@sgugger sgugger left a comment

Mostly have nits on the docstrings. The new example looks great, thanks a lot!

docs/source/en/main_classes/text_generation.mdx
src/transformers/generation_contrastive_search.py
src/transformers/generation_utils.py
@gmftbyGMFTBY
Contributor Author

Oh, I am still working on the integration test.

@gante
Member

gante commented Oct 18, 2022

@gmftbyGMFTBY you probably have to add the @slow decorator to the test, and run it locally with RUN_SLOW=1 py.test (...) to confirm that it is working.

Our CI doesn't run tests with @slow on push (and fails if the test doesn't have the decorator and is actually slow), but we run them every 24h and track them internally :)
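The RUN_SLOW gating that @gante describes can be mimicked with the standard library alone. This is a hedged sketch of the mechanism only — transformers' real `slow` decorator lives in `transformers.testing_utils` and differs in detail:

```python
import os
import unittest

def slow(test_fn):
    # sketch of a @slow decorator: skip the test unless RUN_SLOW is set
    return unittest.skipUnless(
        os.getenv("RUN_SLOW"), "test is slow; set RUN_SLOW=1 to run it"
    )(test_fn)

class DemoTests(unittest.TestCase):
    @slow
    def test_contrastive_search_integration(self):
        self.assertTrue(True)  # placeholder for the real (slow) check

os.environ.pop("RUN_SLOW", None)  # without RUN_SLOW=1, the test is skipped
suite = unittest.TestLoader().loadTestsFromTestCase(DemoTests)
result = unittest.TextTestRunner(verbosity=0).run(suite)
```

Running the suite without `RUN_SLOW=1` reports the test as skipped rather than failed, which is why the push CI stays fast while the nightly job still exercises the slow path.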

@gmftbyGMFTBY
Contributor Author

Ok, I got it!

Member

@gante gante left a comment

Cool, seems ready to go in (except for the minor comment I added)

@patrickvonplaten can you do a final check, please? :)

Comment on lines 1503 to 1510
# 10. prepare logits warper: get the TopKLogitsWarper for contrastive_search
logits_warper = self._get_logits_warper(
top_k=top_k,
top_p=top_p,
typical_p=typical_p,
temperature=temperature,
num_beams=num_beams,
renormalize_logits=renormalize_logits,
)
Member

This shouldn't have been removed :)

Contributor Author

During testing, I found that adding the logits_warper (the TopKLogitsWarper) could influence the generations of contrastive search. Because the TopKLogitsWarper filters out the logits of all tokens outside the top-k and then computes the softmax, the model confidence differs from the case where the TopKLogitsWarper is not used.

So, in this case, I think contrastive search should disable the logits_warper by default.

Is there any solution so that the TopKLogitsWarper is not activated for contrastive search? For example, editing the _get_logits_warper function and passing the penalty_alpha parameter to indicate whether the top_k parameter is used?
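The effect described above is easy to see numerically. This toy illustration (not the PR's code) shows that restricting the softmax to the top-k logits renormalizes the surviving probabilities, so the "model confidence" that contrastive search relies on changes:

```python
import math

def softmax(logits):
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

logits = [3.0, 2.5, 1.0, 0.5]    # toy next-token logits, sorted descending
full_probs = softmax(logits)     # softmax over the full (toy) vocabulary
topk_probs = softmax(logits[:2]) # softmax after a top-k (k=2) filter

# the top token's probability is inflated once the tail is filtered out,
# so a score of (1 - alpha) * prob would shift as well
```

Since the degeneration penalty is weighed against this probability, an inflated confidence would bias the selection, which is why disabling the warper by default is a reasonable choice here.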

Member

Interesting. In that case, I think we can do without the logits warper for now :)

Contributor Author

Okay!

return self.contrastive_search(
input_ids,
top_k=top_k,
penalty_alpha=penalty_alpha,
logits_processor=logits_processor,
logits_warper=logits_warper,
Member

(related to the comment above)

logits_processor: Optional[LogitsProcessorList] = None,
logits_warper: Optional[LogitsProcessorList] = None,
stopping_criteria: Optional[StoppingCriteriaList] = None,
max_length: Optional[int] = None,
Contributor

Suggested change
max_length: Optional[int] = None,

stopping_criteria (`StoppingCriteriaList`, *optional*):
An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
used to tell if the generation loop should stop.
max_length (`int`, *optional*, defaults to 20):
Contributor

Suggested change
max_length (`int`, *optional*, defaults to 20):

Contributor

Let's maybe not add a deprecated argument :-)

logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList()
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
if max_length is not None:
Contributor

I think we should remove this if statement

)
# compute the candidate tokens by the language model and collects their hidden_states
output = self(output_hidden_states=True, **next_model_inputs)
past_key_values = output.past_key_values
Contributor

(nit) - not all language generation models can return past_key_values (e.g. TransfoXL or XLNet), and these models are still used surprisingly often.

Maybe we could add a better error message here?

Suggested change
past_key_values = output.past_key_values
if "past_key_values" not in output:
    raise ValueError(
        f"{self.__class__} cannot return `past_key_values` and can therefore **not** be used for contrastive generation."
    )
past_key_values = output.past_key_values

Contributor Author

Okay!

items = []
# item is either the key or the value matrix
for item in layer:
bsz_and_beam, num_head, seq_len, esz = item.size()
Contributor

I'm not sure it always holds true that the past_key_values have this size. Did we test contrastive search on all of the following models?

  • GPT2
  • T5
  • GPT-J
  • BART

It should work at least on those 4 models.

Contributor Author

Yes, GPT2, T5, BART, GPT-J, and OPT models work fine.

Our implementation is compatible with the encoder-decoder models (the degeneration penalty is calculated on the decoder's hidden states). But we didn't carefully conduct the human evaluation of the encoder-decoder models, such as T5 and BART. Whether contrastive search could significantly boost their performance is still an open problem for us.

@@ -1693,6 +1693,25 @@ def test_diverse_beam_search(self):
],
)

@slow
def test_contrastive_search(self):
Contributor

If possible, I'd be really happy if we could also test this on BART, T5, and GPT-J. Then we should have covered 95% of the model architectures. But it's ok to do this in a follow-up PR. Currently, I don't expect the method to work with T5.

Contributor

@patrickvonplaten patrickvonplaten left a comment

Ok for me to merge, but it'd be nice to always make sure the method works for all of T5, BART, GPT2, and GPT-J. Also, we currently only have slow tests, which is dangerous given that changes in generate can also affect contrastive search.

If it doesn't take too much time, I'd advocate to at least add 7 more tests:

  • 4 fast tests with dummy models that just check that contrastive search outputs the correct shape, one for each of GPT2, T5, BART, GPT-J
  • 3 more slow tests exactly like test_contrastive_search for T5, BART, GPT-J

I leave it up to you @gante to decide :-)

Overall, great work! Thanks a lot @gmftbyGMFTBY for adapting the code so quickly here!

EDIT: Sorry, actually let's please remove the deprecated max_length before merging - that's actually a "must-do" before merging IMO (so not a full approval here 😅)

@gmftbyGMFTBY
Contributor Author

Okay, I am working on it! Thanks a lot for your reviews!

@patrickvonplaten
Contributor

BTW @gmftbyGMFTBY,

Just read through your extremely nice issue! It seems like you experimented with OPT as well, so maybe let's add a test for OPT then? :-) OPT's past_key_values are slightly different from GPT2's past_key_values, so maybe instead of adding tests for both GPT-J and GPT2, it would make more sense to add a test for OPT in addition to GPT2?

Also, if the paper is only concerned with open-ended generation (so less with encoder-decoder architectures), I'm also totally fine with not testing for T5 and BART (it's a nice to have, but if it takes too much time and it's not too important - happy to skip it!).

Regarding the fast dummy test, could you maybe make use of those dummy models:

The tests could look very similar to:

def test_max_new_tokens_decoder_only(self):

just much shorter, i.e. they only need to test for shape equality.
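A fast shape-equality test of the kind suggested above could look roughly like this. It is a sketch with a hypothetical stand-in for `model.generate`; a real test would call a tiny dummy checkpoint with `penalty_alpha` and `top_k` set and compare the shape of the returned ids:

```python
def toy_generate(input_ids, max_new_tokens):
    # hypothetical stand-in for model.generate(..., penalty_alpha=0.6, top_k=4):
    # a fast test only needs to check the shape of the returned token ids,
    # not the quality of the generated text
    return [row + [0] * max_new_tokens for row in input_ids]

batch = [[101, 7, 42], [101, 9, 13]]  # (batch_size=2, seq_len=3)
out = toy_generate(batch, max_new_tokens=5)

# shape check: same batch size, sequence extended by max_new_tokens
shapes_ok = len(out) == len(batch) and all(
    len(row) == len(batch[0]) + 5 for row in out
)
```

Because such a test never loads real weights, it runs on every push and catches shape-breaking regressions in generate immediately, complementing the nightly slow tests.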

@gmftbyGMFTBY
Contributor Author

Yeah, we have already tested the OPT models, and they work fine. I will add more tests for the pre-trained models that you mentioned.

@gmftbyGMFTBY
Contributor Author

gmftbyGMFTBY commented Oct 19, 2022

@patrickvonplaten more tests for these models have been added:

  • gpt2-large
  • gpt-j (EleutherAI/gpt-j-6B)
  • opt (facebook/opt-6.7b)
  • BART (facebook/bart-large-cnn)
  • T5 (flax-community/t5-base-cnn-dm)

These tests pass successfully. Can you do a final check on this PR?

@gante
Member

gante commented Oct 19, 2022

Thank you for being part of this process @gmftbyGMFTBY 🙌 All queries have been addressed and the PR looks in a good state, merging!

@gante gante merged commit 71786b1 into huggingface:main Oct 19, 2022
@gmftbyGMFTBY
Contributor Author

gmftbyGMFTBY commented Oct 19, 2022

@gante @patrickvonplaten @sgugger Wow, thank you very much for your help and support. Love the Hugging Face team!

@yxuansu

yxuansu commented Oct 19, 2022

@gante @patrickvonplaten @sgugger -- Many thanks for your kind help throughout the process! It means a great deal to me and @gmftbyGMFTBY. Huggingface is the best!

@patrickvonplaten
Contributor

Great work @gmftbyGMFTBY and @yxuansu, thanks for bearing with us through the PR :-)

kashif pushed a commit to kashif/transformers that referenced this pull request Oct 21, 2022
…he codebase of generation_utils.py (huggingface#19477)

* add: the contrastive search for generaton_utils

* add: testing scripts for contrastive search under examples/text-generation

* update the quality of codes

* revise the docstring; make the generation_contrastive_search.py scripts;

* revise the examples/pytorch/text-generation/run_generation_contrastive_search.py to the auto-APIs format

* revise the necessary documents

* fix: revise the docstring of generation_contrastive_search.py

* Fix the code indentation

* fix: revise the nits and examples in contrastive_search docstring.

* fix the copyright

* delete generation_contrastive_search.py

* revise the logic in contrastive_search

* update the intergration test and the docstring

* run the tests over

* add the slow decorate to the contrastive_search intergrate test

* add more test

* do the style, quality, consistency checks
Successfully merging this pull request may close these issues:

Adding State-of-the-art Contrastive Search to the Codebase of model.generate()