Adding the state-of-the-art contrastive search decoding method to generation_utils.py #19477

Merged: 39 commits from gmftbyGMFTBY:csearch-pr-v2 into huggingface:main on Oct 19, 2022
Changes from 25 commits
Commits (39)
942a7c6
add: the contrastive search for generaton_utils
gmftbyGMFTBY Oct 10, 2022
5909423
Merge branch 'huggingface:main' into csearch-pr-v2
gmftbyGMFTBY Oct 10, 2022
3e71819
add: testing scripts for contrastive search under examples/text-gener…
gmftbyGMFTBY Oct 10, 2022
47b2b3b
Merge branch 'huggingface:main' into csearch-pr-v2
gmftbyGMFTBY Oct 11, 2022
41e37a5
update the quality of codes
gmftbyGMFTBY Oct 11, 2022
e278a46
Merge branch 'csearch-pr-v2' of https://github.com/gmftbyGMFTBY/trans…
gmftbyGMFTBY Oct 11, 2022
38b100f
Merge branch 'huggingface:main' into csearch-pr-v2
gmftbyGMFTBY Oct 12, 2022
9abd1bb
Merge branch 'huggingface:main' into csearch-pr-v2
gmftbyGMFTBY Oct 12, 2022
32e2a30
revise the docstring; make the generation_contrastive_search.py scripts;
gmftbyGMFTBY Oct 12, 2022
e9e2b26
Merge branch 'huggingface:main' into csearch-pr-v2
gmftbyGMFTBY Oct 12, 2022
1f1dac2
Merge branch 'huggingface:main' into csearch-pr-v2
gmftbyGMFTBY Oct 12, 2022
6226b9a
Merge branch 'huggingface:main' into csearch-pr-v2
gmftbyGMFTBY Oct 12, 2022
f3bfd87
revise the examples/pytorch/text-generation/run_generation_contrastiv…
gmftbyGMFTBY Oct 12, 2022
d3a91b8
Merge branch 'huggingface:main' into csearch-pr-v2
gmftbyGMFTBY Oct 12, 2022
ce26f9f
revise the necessary documents
gmftbyGMFTBY Oct 12, 2022
e801c6f
Merge branch 'huggingface:main' into csearch-pr-v2
gmftbyGMFTBY Oct 13, 2022
c78cf91
Merge branch 'huggingface:main' into csearch-pr-v2
gmftbyGMFTBY Oct 13, 2022
fb4174e
Merge branch 'huggingface:main' into csearch-pr-v2
gmftbyGMFTBY Oct 13, 2022
68429ad
fix: revise the docstring of generation_contrastive_search.py
gmftbyGMFTBY Oct 13, 2022
e1f0db9
Fix the code indentation
gmftbyGMFTBY Oct 13, 2022
42d78be
Merge branch 'huggingface:main' into csearch-pr-v2
gmftbyGMFTBY Oct 13, 2022
d2a5e02
fix: revise the nits and examples in contrastive_search docstring.
gmftbyGMFTBY Oct 13, 2022
1d4f782
Merge branch 'huggingface:main' into csearch-pr-v2
gmftbyGMFTBY Oct 14, 2022
d5f90fb
Merge branch 'huggingface:main' into csearch-pr-v2
gmftbyGMFTBY Oct 14, 2022
d5d30b7
fix the copyright
gmftbyGMFTBY Oct 14, 2022
49000c6
Merge branch 'huggingface:main' into csearch-pr-v2
gmftbyGMFTBY Oct 17, 2022
3058e1c
Merge branch 'huggingface:main' into csearch-pr-v2
gmftbyGMFTBY Oct 17, 2022
c344a0a
delete generation_contrastive_search.py
gmftbyGMFTBY Oct 17, 2022
628ecda
Merge branch 'csearch-pr-v2' of https://github.com/gmftbyGMFTBY/trans…
gmftbyGMFTBY Oct 17, 2022
5ae4ce2
Merge branch 'huggingface:main' into csearch-pr-v2
gmftbyGMFTBY Oct 18, 2022
7af4cbb
revise the logic in contrastive_search
gmftbyGMFTBY Oct 18, 2022
b219a17
Merge branch 'huggingface:main' into csearch-pr-v2
gmftbyGMFTBY Oct 18, 2022
183d7cc
update the intergration test and the docstring
gmftbyGMFTBY Oct 18, 2022
65a1ebd
run the tests over
gmftbyGMFTBY Oct 18, 2022
4972bfb
Merge branch 'huggingface:main' into csearch-pr-v2
gmftbyGMFTBY Oct 18, 2022
e11d342
add the slow decorate to the contrastive_search intergrate test
gmftbyGMFTBY Oct 18, 2022
da014bb
Merge branch 'huggingface:main' into csearch-pr-v2
gmftbyGMFTBY Oct 19, 2022
2aa768c
add more test
gmftbyGMFTBY Oct 19, 2022
ced9f70
do the style, quality, consistency checks
gmftbyGMFTBY Oct 19, 2022
3 changes: 3 additions & 0 deletions docs/source/en/internal/generation_utils.mdx
@@ -14,6 +14,7 @@ specific language governing permissions and limitations under the License.

This page lists all the utility functions used by [`~generation_utils.GenerationMixin.generate`],
[`~generation_utils.GenerationMixin.greedy_search`],
[`~generation_utils.GenerationMixin.contrastive_search`],
[`~generation_utils.GenerationMixin.sample`],
[`~generation_utils.GenerationMixin.beam_search`],
[`~generation_utils.GenerationMixin.beam_sample`],
@@ -261,3 +262,5 @@ A [`Constraint`] can be used to force the generation to include specific tokens
[[autodoc]] top_k_top_p_filtering

[[autodoc]] tf_top_k_top_p_filtering

[[autodoc]] ContrastiveDecodingOneStepFast
1 change: 1 addition & 0 deletions docs/source/en/main_classes/text_generation.mdx
@@ -26,6 +26,7 @@ Each framework has a generate method for auto-regressive text generation impleme
- sample
- beam_search
- beam_sample
- contrastive_search
- group_beam_search
- constrained_beam_search

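For reference, with this PR contrastive search is reached through `generate` by passing both `penalty_alpha` and `top_k`, as the example script below does. A minimal sketch (the checkpoint, prompt, and length here are only illustrative):

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2-large")
model = AutoModelForCausalLM.from_pretrained("gpt2-large")

inputs = tokenizer("DeepMind Company is", return_tensors="pt")
# penalty_alpha > 0 together with top_k > 1 selects the contrastive_search branch
outputs = model.generate(**inputs, penalty_alpha=0.6, top_k=4, max_length=64)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))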
138 changes: 138 additions & 0 deletions examples/pytorch/text-generation/run_generation_contrastive_search.py
@@ -0,0 +1,138 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2022 University of Cambridge, Tencent AI Lab, DeepMind and The University of Hong Kong Authors and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" The examples of running contrastive search on the auto-APIs;

Running this example:
python run_generation_contrastive_search.py --model_name_or_path=gpt2-large --penalty_alpha=0.6 --k=4 --length=256
"""


import argparse
import logging

import numpy as np
import torch

from transformers import AutoModelForCausalLM, AutoTokenizer


logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)


def set_seed(args):
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if args.n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_name_or_path",
        default=None,
        type=str,
        required=True,
    )
    parser.add_argument("--prompt", type=str, default="")
    parser.add_argument("--length", type=int, default=20)
    parser.add_argument("--stop_token", type=str, default=None, help="Token at which text generation is stopped")
    parser.add_argument(
        "--temperature",
        type=float,
        default=1.0,
        help="temperature of 1.0 has no effect, lower values tend toward greedy sampling",
    )
    parser.add_argument(
        "--repetition_penalty", type=float, default=1.0, help="primarily useful for CTRL model; in that case, use 1.2"
    )
    parser.add_argument("--k", type=int, default=0)
    parser.add_argument("--penalty_alpha", type=float, default=0.0)
    parser.add_argument("--p", type=float, default=0.9)

    parser.add_argument("--prefix", type=str, default="", help="Text added prior to input.")
    parser.add_argument("--padding_text", type=str, default="", help="Deprecated, the use of `--prefix` is preferred.")
    parser.add_argument("--xlm_language", type=str, default="", help="Optional language when used with the XLM model.")

    parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
    parser.add_argument("--no_cuda", action="store_true", help="Avoid using CUDA when available")
    parser.add_argument(
        "--fp16",
        action="store_true",
        help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit",
    )
    args = parser.parse_args()

    args.device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
    args.n_gpu = 0 if args.no_cuda else torch.cuda.device_count()

    logger.warning(f"device: {args.device}, n_gpu: {args.n_gpu}, 16-bits training: {args.fp16}")

    set_seed(args)

    # Initialize the model and tokenizer
    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)
    model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path)

    # tokenizer = GPT2Tokenizer.from_pretrained(args.model_name_or_path)
    # model = OPTForCausalLM.from_pretrained(args.model_name_or_path)
    model.to(args.device)

    if args.fp16:
        model.half()

    logger.info(args)
    prompt_text = args.prompt if args.prompt else input("Model prompt >>> ")

    inputs = tokenizer(prompt_text, return_tensors="pt", add_special_tokens=False)
    inputs = {key: value.to(args.device) for key, value in inputs.items()}

    output_sequences = model.generate(
        **inputs,
        max_length=args.length + len(inputs["input_ids"][0]),
        penalty_alpha=args.penalty_alpha,
        top_k=args.k,
    )

    generated_sequences = []
    for generated_sequence_idx, generated_sequence in enumerate(output_sequences):
        print(f"=== GENERATED SEQUENCE {generated_sequence_idx + 1} ===")
        generated_sequence = generated_sequence.tolist()

        # Decode text
        text = tokenizer.decode(generated_sequence, clean_up_tokenization_spaces=True, add_special_tokens=False)

        # Remove all text after the stop token
        text = text[: text.find(args.stop_token) if args.stop_token else None]

        # Add the prompt at the beginning of the sequence. Remove the excess text that was used for pre-processing
        total_sequence = (
            prompt_text + text[len(tokenizer.decode(inputs["input_ids"][0], clean_up_tokenization_spaces=True)) :]
        )

        generated_sequences.append(total_sequence)
        print(total_sequence)

    return generated_sequences


if __name__ == "__main__":
    main()
2 changes: 2 additions & 0 deletions src/transformers/__init__.py
@@ -802,6 +802,7 @@
"PhrasalConstraint",
]
_import_structure["generation_beam_search"] = ["BeamScorer", "BeamSearchScorer", "ConstrainedBeamSearchScorer"]
_import_structure["generation_contrastive_search"] = ["ContrastiveDecodingOneStepFast"]
_import_structure["generation_logits_process"] = [
"ForcedBOSTokenLogitsProcessor",
"ForcedEOSTokenLogitsProcessor",
@@ -3750,6 +3751,7 @@
PhrasalConstraint,
)
from .generation_beam_search import BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer
from .generation_contrastive_search import ContrastiveDecodingOneStepFast
Contributor (review comment):
I would maybe not put this in the public init as it could be prone to change in the future and I don't think most people will use it outside of generate no?

from .generation_logits_process import (
ForcedBOSTokenLogitsProcessor,
ForcedEOSTokenLogitsProcessor,
196 changes: 196 additions & 0 deletions src/transformers/generation_contrastive_search.py
@@ -0,0 +1,196 @@
# coding=utf-8
# Copyright 2022 University of Cambridge, Tencent AI Lab, DeepMind and The University of Hong Kong Authors and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file contains the utility functions for contrastive search, which are called in `generation_utils`
"""

from typing import Tuple

import torch
from torch import nn

from .utils import logging


logger = logging.get_logger(__name__)


def ranking_fast(
Contributor (review comment):
If possible I'd be in favor of not adding a new file here. IMO ranking_fast can be changed to a logit processor and ContrastiveDecodingOneStepFast could be fully moved into generation_utils.py right away.

IMO this will be easier to maintain and understand for people who already know how generate works - wdyt @gante ?

    context_hidden: torch.FloatTensor,
    next_hidden: torch.FloatTensor,
    next_top_k_probs: torch.FloatTensor,
    alpha: float,
    beam_width: int,
) -> Tuple[torch.FloatTensor]:
    """
    Re-ranks the top-k candidates by the model confidence minus the degeneration penalty.

    Shapes: context_hidden: [bsz*beam, seqlen, embed_dim]; next_hidden: [bsz*beam, 1, embed_dim];
    next_top_k_probs: [bsz, beam]
    """
    _, context_len, embed_dim = context_hidden.size()
    norm_context_hidden = context_hidden / context_hidden.norm(dim=2, keepdim=True)
    norm_next_hidden = next_hidden / next_hidden.norm(dim=2, keepdim=True)
    cosine_matrix = torch.matmul(norm_context_hidden, norm_next_hidden.transpose(1, 2)).squeeze(-1)  # [B*K, S]
    scores, _ = torch.max(cosine_matrix, dim=-1)  # [B*K]
    next_top_k_probs = next_top_k_probs.view(-1)  # [B*K]
    scores = (1.0 - alpha) * next_top_k_probs - alpha * scores
    scores = torch.stack(torch.split(scores, beam_width))  # [B, K]
    selected_scores, selected_idx = scores.max(dim=-1)  # [B]
    return selected_scores, selected_idx
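As a quick sanity check of ranking_fast under the tensor layout described in its docstring, a small sketch with dummy tensors (not part of the PR):

import torch

# batch of 2, beam width (k) of 4, context length 10, hidden size 8
bsz, beam_width, seqlen, embed_dim = 2, 4, 10, 8
context_hidden = torch.randn(bsz * beam_width, seqlen, embed_dim)
next_hidden = torch.randn(bsz * beam_width, 1, embed_dim)
next_top_k_probs = torch.softmax(torch.randn(bsz, beam_width), dim=-1)

selected_scores, selected_idx = ranking_fast(
    context_hidden, next_hidden, next_top_k_probs, alpha=0.6, beam_width=beam_width
)
print(selected_scores.shape, selected_idx.shape)  # torch.Size([2]) torch.Size([2])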


def ContrastiveDecodingOneStepFast(
    model,
Contributor (review comment):
not very happy about passing the model into the function here and upper-casing the function.

Could we try to have this whole functionality directly inside the contrastive_search function? It would help readability a lot IMO.

@gante - what do you think?

    beam_width: int = 1,
    penalty_alpha: float = 0.0,
    past_key_values: Tuple[Tuple[torch.FloatTensor]] = None,
    last_hidden_states: torch.FloatTensor = None,
    logit_for_next_step: torch.FloatTensor = None,
    is_encoder_decoder: bool = False,
    **model_inputs,
) -> Tuple:
"""
contrastive search first selects top-k candidates by the logit scores; then these candidate tokens are fed into the
language models to compute the degeneration penalty, which will be used to re-rank these candidate tokens.

Args:
model ([`PretrainedModel`]): The decoder-only or encoder-decoder generation language models.
beam_width (`int`, *optional*, defauls to 1):
If > 1, top k candidate tokens are selected for re-ranking.
penalty_alpha (`float`, *optional*, defaults to 0.0):
If > 0, the model confidence will minus the degeneration penalty that is weighted by the penalty_alpha
parameter.
past_key_value (`Tuple[Tuple[torch.FloatTensor]]`):
It saves the cached key-value results that computed by previous steps.
last_hidden_states (`torch.FloatTensor`, *optional*):
The last_hidden_states generated by the previous step.
logit_for_next_step (`torch.FloatTensor`, *optional*):
The logit_for_next_step generated by the previous step.
is_encoder_decoder (`bool`, *optional*, defaults to `False`):
If `True`, the model is an encoder-decode model, otherwise the model is a decoder-only model.

From: https://github.com/yxuansu/SimCTG/blob/main/contrastive_search_explanation/README.md
"""
    bsz, seqlen, embed_dim = last_hidden_states.size()
    next_probs = nn.functional.softmax(logit_for_next_step, dim=-1)
    _, top_k_ids = torch.topk(logit_for_next_step, dim=-1, k=beam_width)
    top_k_probs = torch.gather(next_probs, dim=1, index=top_k_ids)
    past_key_values = enlarge_past_key_values(past_key_values, beam_width)

    # build next attention mask
    if "attention_mask" in model_inputs:
        attention_mask = model_inputs["attention_mask"]  # [B, S]
        # decoder-only models need the full attention mask, not only the mask for the last token
        if is_encoder_decoder is False:
            attention_mask = torch.cat([attention_mask, attention_mask.new_ones((bsz, 1))], dim=-1)
        attention_mask = attention_mask.unsqueeze(1).expand(-1, beam_width, -1).reshape(-1, attention_mask.size(-1))
    else:
        attention_mask = None

    # encoder-decoder models also contain the `encoder_outputs`
    if is_encoder_decoder and "encoder_outputs" in model_inputs:
        encoder_outputs = model_inputs["encoder_outputs"]
    else:
        encoder_outputs = None
    next_model_inputs = model.prepare_inputs_for_generation(
        top_k_ids.view(-1, 1),
        past=past_key_values,
        attention_mask=attention_mask,
        use_cache=True,
        encoder_outputs=encoder_outputs,
    )
    # compute the candidate tokens with the language model and collect their hidden_states
    output = model(output_hidden_states=True, **next_model_inputs)
    past_key_values = output.past_key_values
    logits = output.logits[:, -1, :]
    # the name is different for encoder-decoder and decoder-only models
    if is_encoder_decoder:
        next_hidden = output.decoder_hidden_states[-1]
        full_hidden_states = output.decoder_hidden_states
    else:
        next_hidden = output.hidden_states[-1]
        full_hidden_states = output.hidden_states
    context_hidden = (
        last_hidden_states.unsqueeze(1).expand(-1, beam_width, -1, -1).reshape(bsz * beam_width, seqlen, embed_dim)
    )

    # compute the degeneration penalty and re-rank the candidates based on the degeneration penalty and the model confidence
    # the scores and index of the selected tokens are returned
    selected_scores, selected_idx = ranking_fast(
Contributor (review comment):
@gante I think we can/should make this a logit processor that exceptionally takes context_hidden and next_hidden as input arguments as well. top_k_probs can be computed inside the logit processor, and penalty_alpha and beam_width can be stored inside the logit processor at init

Contributor (review comment):

IMO the logit processor could then (just like all other processors) return the scores, and we compute selected_idx afterwards

        context_hidden,
        next_hidden,
        top_k_probs,
        penalty_alpha,
        beam_width,
    )
    # prepare for the next step: (1) next token_id; (2) past_key_values; (3) last_hidden_states for computing the degeneration penalty; (4) logits for selecting next top-k candidates; (5) selected tokens scores (model confidence minus degeneration penalty); (6) decoder hidden_states
    next_id = top_k_ids[range(len(top_k_ids)), selected_idx].unsqueeze(-1)
    next_hidden = torch.stack(torch.split(next_hidden.squeeze(dim=1), beam_width))
    next_hidden = next_hidden[range(bsz), selected_idx, :]
    last_hidden_states = torch.cat([last_hidden_states, next_hidden.unsqueeze(1)], dim=1)

    decoder_hidden_states = []
    for layer in full_hidden_states:
        layer = torch.stack(torch.split(layer.squeeze(dim=1), beam_width))
        layer = layer[range(bsz), selected_idx, :]
        decoder_hidden_states.append(layer)

    past_key_values = select_past_key_values(past_key_values, beam_width, selected_idx)
    logits = torch.stack(torch.split(logits, beam_width))[range(bsz), selected_idx, :]
    return next_id.squeeze(dim=-1), past_key_values, last_hidden_states, logits, selected_scores, decoder_hidden_states


def enlarge_past_key_values(
    past_key_values: Tuple[Tuple[torch.FloatTensor]], beam_width: int
) -> Tuple[Tuple[torch.FloatTensor]]:
    """
    Copy and extend the past_key_values for the next-step re-ranking. Each item in `past_key_values` is a
    4-dimensional tensor of shape [batch_size, num_head, seq_len, embed_dim]. If the number of next-token candidates
    is K, we need to obtain new `past_key_values` of shape [batch_size*K, num_head, seq_len, embed_dim].
    """
    # from [B, num_head, seq_len, esz] to [B*K, num_head, seq_len, esz]
    new_key_values = []
    for layer in past_key_values:
        items = []
        # item is either the key or the value matrix
        for item in layer:
            bsz, num_head, seq_len, esz = item.size()
            item = (
                item.unsqueeze(1).expand(-1, beam_width, -1, -1, -1).reshape(bsz * beam_width, num_head, seq_len, esz)
            )  # [bsz*beam, num_head, seq_len, esz]
            items.append(item)
        new_key_values.append(items)
    return new_key_values
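A quick shape check of enlarge_past_key_values with dummy tensors (a sketch, not part of the PR; 2 layers, batch size 2, 1 head, 3 tokens, head dim 8):

import torch

past = [[torch.randn(2, 1, 3, 8), torch.randn(2, 1, 3, 8)] for _ in range(2)]
enlarged = enlarge_past_key_values(past, beam_width=4)
# each key/value goes from [batch_size, num_head, seq_len, embed_dim] to [batch_size*K, ...]
print(enlarged[0][0].shape)  # torch.Size([8, 1, 3, 8])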


def select_past_key_values(
    past_key_values: Tuple[Tuple[torch.FloatTensor]], beam_width: int, selected_idx: torch.FloatTensor
) -> Tuple[Tuple[torch.FloatTensor]]:
    """
    Extract the `past_key_values` for the selected tokens. Each item in `past_key_values` is a 4-dimensional tensor
    of shape [batch_size*K, num_head, seq_len, embed_dim], where K is the number of candidate tokens. We aim to
    obtain the `past_key_values` of the selected next token, of shape [batch_size, num_head, seq_len, embed_dim].
    """
    new_key_values = []
    for layer in past_key_values:
        items = []
        # item is either the key or the value matrix
        for item in layer:
            bsz_and_beam, num_head, seq_len, esz = item.size()
            bsz = int(bsz_and_beam // beam_width)
            item = torch.stack(torch.split(item, beam_width, dim=0))  # [B, K, num_head, seq_len, esz]
            item = item[range(bsz), selected_idx, :, :, :]  # [B, num_head, seq_len, esz]
            items.append(item)
        new_key_values.append(items)
    return new_key_values
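To make the reviewers' logit-processor suggestion above concrete, here is a minimal sketch of what such a processor might look like, with penalty_alpha and beam_width stored at init and top_k_probs computed inside the call; the class name and the (non-standard) call signature are hypothetical and not part of this PR:

import torch


class DegenerationPenaltyLogitsProcessor:
    # Hypothetical sketch only: re-scores the top-k candidates with the contrastive-search
    # degeneration penalty, mirroring the math in ranking_fast above.

    def __init__(self, penalty_alpha: float, beam_width: int):
        self.penalty_alpha = penalty_alpha
        self.beam_width = beam_width

    def __call__(self, logit_for_next_step, top_k_ids, context_hidden, next_hidden):
        # model confidence of the k candidate tokens, computed inside the processor
        next_probs = torch.nn.functional.softmax(logit_for_next_step, dim=-1)
        top_k_probs = torch.gather(next_probs, dim=1, index=top_k_ids).view(-1)  # [B*K]

        # degeneration penalty: maximum cosine similarity between each candidate and the context
        norm_context = context_hidden / context_hidden.norm(dim=2, keepdim=True)
        norm_next = next_hidden / next_hidden.norm(dim=2, keepdim=True)
        cosine = torch.matmul(norm_context, norm_next.transpose(1, 2)).squeeze(-1)  # [B*K, S]
        penalty, _ = cosine.max(dim=-1)  # [B*K]

        # contrastive score; the caller then computes selected_idx = scores.argmax(dim=-1)
        scores = (1.0 - self.penalty_alpha) * top_k_probs - self.penalty_alpha * penalty
        return torch.stack(torch.split(scores, self.beam_width))  # [B, K]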