Adding the state-of-the-art contrastive search decoding methods for the codebase of generation_utils.py #19477
@@ -0,0 +1,138 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2022 University of Cambridge, Tencent AI Lab, DeepMind and The University of Hong Kong Authors and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Example of running contrastive search with the auto-APIs.

Running this example:
python run_generation_contrastive_search.py --model_name_or_path=gpt2-large --penalty_alpha=0.6 --k=4 --length=256
"""


import argparse
import logging

import numpy as np
import torch

from transformers import AutoModelForCausalLM, AutoTokenizer


logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)


def set_seed(args):
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if args.n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_name_or_path",
        default=None,
        type=str,
        required=True,
    )
    parser.add_argument("--prompt", type=str, default="")
    parser.add_argument("--length", type=int, default=20)
    parser.add_argument("--stop_token", type=str, default=None, help="Token at which text generation is stopped")
    parser.add_argument(
        "--temperature",
        type=float,
        default=1.0,
        help="temperature of 1.0 has no effect, lower values tend toward greedy sampling",
    )
    parser.add_argument(
        "--repetition_penalty", type=float, default=1.0, help="primarily useful for CTRL model; in that case, use 1.2"
    )
    parser.add_argument("--k", type=int, default=0)
    parser.add_argument("--penalty_alpha", type=float, default=0.0)
    parser.add_argument("--p", type=float, default=0.9)

    parser.add_argument("--prefix", type=str, default="", help="Text added prior to input.")
    parser.add_argument("--padding_text", type=str, default="", help="Deprecated, the use of `--prefix` is preferred.")
    parser.add_argument("--xlm_language", type=str, default="", help="Optional language when used with the XLM model.")

    parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
    parser.add_argument("--no_cuda", action="store_true", help="Avoid using CUDA when available")
    parser.add_argument(
        "--fp16",
        action="store_true",
        help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit",
    )
    args = parser.parse_args()

    args.device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
    args.n_gpu = 0 if args.no_cuda else torch.cuda.device_count()

    logger.warning(f"device: {args.device}, n_gpu: {args.n_gpu}, 16-bit inference: {args.fp16}")

    set_seed(args)

    # Initialize the model and tokenizer
    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)
    model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path)

    # tokenizer = GPT2Tokenizer.from_pretrained(args.model_name_or_path)
    # model = OPTForCausalLM.from_pretrained(args.model_name_or_path)
    model.to(args.device)

    if args.fp16:
        model.half()

    logger.info(args)
    prompt_text = args.prompt if args.prompt else input("Model prompt >>> ")

    inputs = tokenizer(prompt_text, return_tensors="pt", add_special_tokens=False)
    inputs = {key: value.to(args.device) for key, value in inputs.items()}

    output_sequences = model.generate(
        **inputs,
        max_length=args.length + len(inputs["input_ids"][0]),
        penalty_alpha=args.penalty_alpha,
        top_k=args.k,
    )

    generated_sequences = []
    for generated_sequence_idx, generated_sequence in enumerate(output_sequences):
        print(f"=== GENERATED SEQUENCE {generated_sequence_idx + 1} ===")
        generated_sequence = generated_sequence.tolist()

        # Decode text
        text = tokenizer.decode(generated_sequence, clean_up_tokenization_spaces=True, add_special_tokens=False)

        # Remove all text after the stop token
        text = text[: text.find(args.stop_token) if args.stop_token else None]

        # Add the prompt at the beginning of the sequence. Remove the excess text that was used for pre-processing
        total_sequence = (
            prompt_text + text[len(tokenizer.decode(inputs["input_ids"][0], clean_up_tokenization_spaces=True)) :]
        )

        generated_sequences.append(total_sequence)
        print(total_sequence)

    return generated_sequences


if __name__ == "__main__":
    main()
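For quick reference, the new decoding path is reached purely through the existing `model.generate` API, as the script above shows. A minimal sketch of the same call in a few lines (the prompt text is arbitrary, and this assumes a build of `transformers` that includes this PR):

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2-large")
model = AutoModelForCausalLM.from_pretrained("gpt2-large")

inputs = tokenizer("DeepMind Company is", return_tensors="pt")
# penalty_alpha > 0 combined with top_k > 1 selects the contrastive search path
outputs = model.generate(**inputs, penalty_alpha=0.6, top_k=4, max_length=64)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))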
@@ -0,0 +1,196 @@
# coding=utf-8
# Copyright 2022 University of Cambridge, Tencent AI Lab, DeepMind and The University of Hong Kong Authors and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file contains the utility functions for contrastive search, which will be called in `generation_utils`
"""

from typing import Tuple

import torch
from torch import nn

from .utils import logging


logger = logging.get_logger(__name__)


def ranking_fast(
Review comment: If possible I'd be in favor of not adding a new file here. IMO this will be easier to maintain and understand for people that already know how `generation_utils.py` works.
    context_hidden: torch.FloatTensor,
    next_hidden: torch.FloatTensor,
    next_top_k_probs: torch.FloatTensor,
    alpha: float,
    beam_width: int,
) -> Tuple[torch.FloatTensor]:
    """
    context_hidden: bsz*beam x seqlen x embed_dim
    next_hidden: bsz*beam x 1 x embed_dim
    next_top_k_probs: bsz x beam
    """
    _, context_len, embed_dim = context_hidden.size()
    norm_context_hidden = context_hidden / context_hidden.norm(dim=2, keepdim=True)
    norm_next_hidden = next_hidden / next_hidden.norm(dim=2, keepdim=True)
    cosine_matrix = torch.matmul(norm_context_hidden, norm_next_hidden.transpose(1, 2)).squeeze(-1)  # [B*K, S]
    scores, _ = torch.max(cosine_matrix, dim=-1)  # [B*K]
    next_top_k_probs = next_top_k_probs.view(-1)  # [B*K]
    scores = (1.0 - alpha) * next_top_k_probs - alpha * scores
    scores = torch.stack(torch.split(scores, beam_width))  # [B, K]
    selected_scores, selected_idx = scores.max(dim=-1)  # [B]
    return selected_scores, selected_idx


def ContrastiveDecodingOneStepFast(
    model,
Review comment: Not very happy about passing the `model` here. Could we try to have this whole functionality directly inside `generation_utils.py`? @gante - what do you think?
    beam_width: int = 1,
    penalty_alpha: float = 0.0,
    past_key_values: Tuple[Tuple[torch.FloatTensor]] = None,
    last_hidden_states: torch.FloatTensor = None,
    logit_for_next_step: torch.FloatTensor = None,
    is_encoder_decoder: bool = False,
    **model_inputs,
) -> Tuple:
    """
    Contrastive search first selects the top-k candidates by their logit scores; these candidate tokens are then fed
    into the language model to compute the degeneration penalty, which is used to re-rank the candidates.

    Args:
        model ([`PreTrainedModel`]): The decoder-only or encoder-decoder language model used for generation.
        beam_width (`int`, *optional*, defaults to 1):
            If > 1, the top k candidate tokens are selected for re-ranking.
        penalty_alpha (`float`, *optional*, defaults to 0.0):
            If > 0, the degeneration penalty, weighted by `penalty_alpha`, is subtracted from the model confidence.
        past_key_values (`Tuple[Tuple[torch.FloatTensor]]`):
            The cached key-value states computed in previous steps.
        last_hidden_states (`torch.FloatTensor`, *optional*):
            The last_hidden_states generated by the previous step.
        logit_for_next_step (`torch.FloatTensor`, *optional*):
            The logit_for_next_step generated by the previous step.
        is_encoder_decoder (`bool`, *optional*, defaults to `False`):
            If `True`, the model is an encoder-decoder model; otherwise, it is a decoder-only model.

    From: https://github.com/yxuansu/SimCTG/blob/main/contrastive_search_explanation/README.md
    """
    bsz, seqlen, embed_dim = last_hidden_states.size()
    next_probs = nn.functional.softmax(logit_for_next_step, dim=-1)
    _, top_k_ids = torch.topk(logit_for_next_step, dim=-1, k=beam_width)
    top_k_probs = torch.gather(next_probs, dim=1, index=top_k_ids)
    past_key_values = enlarge_past_key_values(past_key_values, beam_width)

    # build the next attention mask
    if "attention_mask" in model_inputs:
        attention_mask = model_inputs["attention_mask"]  # [B, S]
        # decoder-only models need the full attention mask, not only the mask for the last token
        if is_encoder_decoder is False:
            attention_mask = torch.cat([attention_mask, attention_mask.new_ones((bsz, 1))], dim=-1)
        attention_mask = attention_mask.unsqueeze(1).expand(-1, beam_width, -1).reshape(-1, attention_mask.size(-1))
    else:
        attention_mask = None

    # encoder-decoder models also contain the `encoder_outputs`
    if is_encoder_decoder and "encoder_outputs" in model_inputs:
        encoder_outputs = model_inputs["encoder_outputs"]
    else:
        encoder_outputs = None
    next_model_inputs = model.prepare_inputs_for_generation(
        top_k_ids.view(-1, 1),
        past=past_key_values,
        attention_mask=attention_mask,
        use_cache=True,
        encoder_outputs=encoder_outputs,
    )
    # feed the candidate tokens through the language model and collect their hidden_states
    output = model(output_hidden_states=True, **next_model_inputs)
    past_key_values = output.past_key_values
    logits = output.logits[:, -1, :]
    # the attribute name differs for encoder-decoder and decoder-only models
    if is_encoder_decoder:
        next_hidden = output.decoder_hidden_states[-1]
        full_hidden_states = output.decoder_hidden_states
    else:
        next_hidden = output.hidden_states[-1]
        full_hidden_states = output.hidden_states
    context_hidden = (
        last_hidden_states.unsqueeze(1).expand(-1, beam_width, -1, -1).reshape(bsz * beam_width, seqlen, embed_dim)
    )

    # compute the degeneration penalty and re-rank the candidates based on the degeneration penalty and the
    # model confidence; the scores and index of the selected tokens are returned
    selected_scores, selected_idx = ranking_fast(
Review comment: @gante I think we can/should make this a logit processor that exceptionally takes the hidden states as input. IMO the logit processor could then (just like all other processors) return the processed scores.
        context_hidden,
        next_hidden,
        top_k_probs,
        penalty_alpha,
        beam_width,
    )
    # prepare for the next step: (1) next token_id; (2) past_key_values; (3) last_hidden_states for computing
    # the degeneration penalty; (4) logits for selecting the next top-k candidates; (5) the selected tokens'
    # scores (model confidence minus degeneration penalty); (6) decoder hidden_states
    next_id = top_k_ids[range(len(top_k_ids)), selected_idx].unsqueeze(-1)
    next_hidden = torch.stack(torch.split(next_hidden.squeeze(dim=1), beam_width))
    next_hidden = next_hidden[range(bsz), selected_idx, :]
    last_hidden_states = torch.cat([last_hidden_states, next_hidden.unsqueeze(1)], dim=1)

    decoder_hidden_states = []
    for layer in full_hidden_states:
        layer = torch.stack(torch.split(layer.squeeze(dim=1), beam_width))
        layer = layer[range(bsz), selected_idx, :]
        decoder_hidden_states.append(layer)

    past_key_values = select_past_key_values(past_key_values, beam_width, selected_idx)
    logits = torch.stack(torch.split(logits, beam_width))[range(bsz), selected_idx, :]
    return next_id.squeeze(dim=-1), past_key_values, last_hidden_states, logits, selected_scores, decoder_hidden_states


def enlarge_past_key_values(
    past_key_values: Tuple[Tuple[torch.FloatTensor]], beam_width: int
) -> Tuple[Tuple[torch.FloatTensor]]:
    """
    Copy and extend the `past_key_values` for the next-step re-ranking. Each item in `past_key_values` is a
    4-dimensional matrix of shape [batch_size, num_head, seq_len, embed_dim]. Suppose the size of the next-token
    candidate set is K; we need to obtain the new `past_key_values` of shape [batch_size*K, num_head, seq_len,
    embed_dim].
    """
    # from [B, num_head, seq_len, esz] to [B*K, num_head, seq_len, esz]
    new_key_values = []
    for layer in past_key_values:
        items = []
        # item is either the key or the value matrix
        for item in layer:
            bsz, num_head, seq_len, esz = item.size()
            item = (
                item.unsqueeze(1).expand(-1, beam_width, -1, -1, -1).reshape(bsz * beam_width, num_head, seq_len, esz)
            )  # [bsz*beam, num_head, seq_len, esz]
            items.append(item)
        new_key_values.append(items)
    return new_key_values


def select_past_key_values(
    past_key_values: Tuple[Tuple[torch.FloatTensor]], beam_width: int, selected_idx: torch.FloatTensor
) -> Tuple[Tuple[torch.FloatTensor]]:
    """
    Extract the `past_key_values` for the selected tokens. Each item in `past_key_values` is a 4-dimensional matrix
    of shape [batch_size*K, num_head, seq_len, embed_dim], where K is the number of candidate tokens. We aim to
    obtain the `past_key_values` of the selected next token, of shape [batch_size, num_head, seq_len, embed_dim].
    """
    new_key_values = []
    for layer in past_key_values:
        items = []
        # item is either the key or the value matrix
        for item in layer:
            bsz_and_beam, num_head, seq_len, esz = item.size()
            bsz = int(bsz_and_beam // beam_width)
            item = torch.stack(torch.split(item, beam_width, dim=0))  # [B, K, num_head, seq_len, esz]
            item = item[range(bsz), selected_idx, :, :, :]  # [B, num_head, seq_len, esz]
            items.append(item)
        new_key_values.append(items)
    return new_key_values
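To make the re-ranking rule concrete: for each candidate token, `ranking_fast` computes `(1 - alpha) * model_confidence - alpha * max_cosine_similarity(candidate, context)` and keeps the best-scoring candidate per batch item. A self-contained toy run of that same rule (shapes and values invented purely for illustration):

import torch

# toy setup: batch size B=1, beam width K=2, context length S=3, hidden size E=4
torch.manual_seed(0)
context_hidden = torch.randn(2, 3, 4)  # [B*K, S, E] - context states, one copy per candidate
next_hidden = torch.randn(2, 1, 4)     # [B*K, 1, E] - hidden state of each candidate token
next_top_k_probs = torch.tensor([[0.5, 0.3]])  # [B, K] - model confidence of each candidate

alpha, beam_width = 0.6, 2

# degeneration penalty: maximum cosine similarity between a candidate and any context state
norm_context = context_hidden / context_hidden.norm(dim=2, keepdim=True)
norm_next = next_hidden / next_hidden.norm(dim=2, keepdim=True)
cosine = torch.matmul(norm_context, norm_next.transpose(1, 2)).squeeze(-1)  # [B*K, S]
penalty, _ = cosine.max(dim=-1)  # [B*K]

scores = (1.0 - alpha) * next_top_k_probs.view(-1) - alpha * penalty  # [B*K]
scores = torch.stack(torch.split(scores, beam_width))  # [B, K]
selected_scores, selected_idx = scores.max(dim=-1)
print(selected_idx)  # which of the K candidates wins for each batch item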
Review comment: I would maybe not put this in the public init, as it could be prone to change in the future, and I don't think most people will use it outside of `generate`, no?
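As a rough illustration of the logits-processor direction suggested in the review above, a hypothetical sketch follows. It is not what this PR implements, and the extra hidden-state arguments deliberately deviate from the standard `LogitsProcessor.__call__(input_ids, scores)` signature, which is exactly the "exceptionally takes" caveat under discussion:

import torch
from transformers import LogitsProcessor


class DegenerationPenaltyProcessor(LogitsProcessor):
    """Hypothetical: re-rank the top-k candidates with the degeneration penalty."""

    def __init__(self, penalty_alpha: float):
        self.penalty_alpha = penalty_alpha

    # Non-standard signature: regular processors only receive (input_ids, scores).
    # Here `scores` is the model confidence of each candidate, flattened to [B*K],
    # and the hidden states have the same layout as in ranking_fast above.
    def __call__(self, input_ids, scores, context_hidden=None, next_hidden=None):
        norm_context = context_hidden / context_hidden.norm(dim=2, keepdim=True)
        norm_next = next_hidden / next_hidden.norm(dim=2, keepdim=True)
        cosine = torch.matmul(norm_context, norm_next.transpose(1, 2)).squeeze(-1)  # [B*K, S]
        penalty, _ = cosine.max(dim=-1)  # [B*K]
        # same rule as ranking_fast: confidence minus the weighted degeneration penalty
        return (1.0 - self.penalty_alpha) * scores - self.penalty_alpha * penalty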