From 942a7c63e1bc3a57fac0c1485a89353612f11355 Mon Sep 17 00:00:00 2001 From: gmftbyGMFTBY <18811371908@163.com> Date: Mon, 10 Oct 2022 18:53:54 +0800 Subject: [PATCH 01/17] add: the contrastive search for generaton_utils --- src/transformers/generation_utils.py | 407 ++++++++++++++++++++++++++- 1 file changed, 402 insertions(+), 5 deletions(-) diff --git a/src/transformers/generation_utils.py b/src/transformers/generation_utils.py index 380eec07270c9..cf00c6b0d092a 100644 --- a/src/transformers/generation_utils.py +++ b/src/transformers/generation_utils.py @@ -1328,19 +1328,24 @@ def generate( # 6. determine generation mode is_constraint_gen_mode = constraints is not None or force_words_ids is not None + + is_contrastive_search_gen_mode = ( + top_k is not None and top_k > 1 and do_sample is False and penalty_alpha is not None and penalty_alpha > 0 + ) + is_greedy_gen_mode = ( - (num_beams == 1) and (num_beam_groups == 1) and do_sample is False and not is_constraint_gen_mode + (num_beams == 1) and (num_beam_groups == 1) and do_sample is False and not is_constraint_gen_mode and not is_contrastive_search_gen_mode ) is_sample_gen_mode = ( - (num_beams == 1) and (num_beam_groups == 1) and do_sample is True and not is_constraint_gen_mode + (num_beams == 1) and (num_beam_groups == 1) and do_sample is True and not is_constraint_gen_mode and not is_contrastive_search_gen_mode ) is_beam_gen_mode = ( - (num_beams > 1) and (num_beam_groups == 1) and do_sample is False and not is_constraint_gen_mode + (num_beams > 1) and (num_beam_groups == 1) and do_sample is False and not is_constraint_gen_mode and not is_contrastive_search_gen_mode ) is_beam_sample_gen_mode = ( - (num_beams > 1) and (num_beam_groups == 1) and do_sample is True and not is_constraint_gen_mode + (num_beams > 1) and (num_beam_groups == 1) and do_sample is True and not is_constraint_gen_mode and not is_contrastive_search_gen_mode ) - is_group_beam_gen_mode = (num_beams > 1) and (num_beam_groups > 1) and not is_constraint_gen_mode + is_group_beam_gen_mode = (num_beams > 1) and (num_beam_groups > 1) and not is_constraint_gen_mode and not is_contrastive_search_gen_mode if num_beam_groups > num_beams: raise ValueError("`num_beam_groups` has to be smaller or equal to `num_beams`") @@ -1399,6 +1404,27 @@ def generate( **model_kwargs, ) + elif is_contrastive_search_gen_mode: + + if num_return_sequences > 1: + raise ValueError( + f"num_return_sequences has to be 1, but is {num_return_sequences} when doing greedy search." + ) + + return self.contrastive_search( + input_ids, + top_k=top_k, + penalty_alpha=penalty_alpha, + logits_processor=logits_processor, + stopping_criteria=stopping_criteria, + pad_token_id=pad_token_id, + eos_token_id=eos_token_id, + output_scores=output_scores, + return_dict_in_generate=return_dict_in_generate, + synced_gpus=synced_gpus, + **model_kwargs, + ) + elif is_sample_gen_mode: # 10. prepare logits warper logits_warper = self._get_logits_warper( @@ -1634,6 +1660,228 @@ def typeerror(): **model_kwargs, ) + @torch.no_grad() + def contrastive_search( + self, + input_ids: torch.LongTensor, + top_k: Optional[int] = 1, + penalty_alpha: Optional[float] = 0, + logits_processor: Optional[LogitsProcessorList] = None, + stopping_criteria: Optional[StoppingCriteriaList] = None, + max_length: Optional[int] = None, + pad_token_id: Optional[int] = None, + eos_token_id: Optional[int] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_scores: Optional[bool] = None, + return_dict_in_generate: Optional[bool] = None, + synced_gpus: Optional[bool] = False, + **model_kwargs, + ) -> Union[GreedySearchOutput, torch.LongTensor]: + r""" + Generates sequences of token ids for models with a language modeling head using **contrastive search** and can + be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models. + Parameters: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + The sequence used as a prompt for the generation. + top_k (`int`, *optional*, defaults to 1): + The size of the candidate set that is used to re-rank for contrastive search + penalty_alpha (`float`, *optional*, defaults to 0): + The degeneration penalty for contrastive search; activate when it is larger than 0 + logits_processor (`LogitsProcessorList`, *optional*): + An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`] + used to modify the prediction scores of the language modeling head applied at each generation step. + stopping_criteria (`StoppingCriteriaList`, *optional*): + An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`] + used to tell if the generation loop should stop. + max_length (`int`, *optional*, defaults to 20): + **DEPRECATED**. Use `logits_processor` or `stopping_criteria` directly to cap the number of generated + tokens. The maximum length of the sequence to be generated. + pad_token_id (`int`, *optional*): + The id of the *padding* token. + eos_token_id (`int`, *optional*): + The id of the *end-of-sequence* token. + output_attentions (`bool`, *optional*, defaults to `False`): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more details. + output_hidden_states (`bool`, *optional*, defaults to `False`): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more details. + output_scores (`bool`, *optional*, defaults to `False`): + Whether or not to return the prediction scores. See `scores` under returned tensors for more details. + return_dict_in_generate (`bool`, *optional*, defaults to `False`): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + synced_gpus (`bool`, *optional*, defaults to `False`): + Whether to continue running the while loop until max_length (needed for ZeRO stage 3) + model_kwargs: + Additional model specific keyword arguments will be forwarded to the `forward` function of the model. + If model is an encoder-decoder model the kwargs should include `encoder_outputs`. + Return: + [`~generation_utils.GreedySearchDecoderOnlyOutput`], [`~generation_utils.GreedySearchEncoderDecoderOutput`] + or `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a + [`~generation_utils.GreedySearchDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and + `return_dict_in_generate=True` or a [`~generation_utils.GreedySearchEncoderDecoderOutput`] if + `model.config.is_encoder_decoder=True`. + Examples: + ```python + >>> from transformers import ( + ... AutoTokenizer, + ... AutoModelForCausalLM, + ... LogitsProcessorList, + ... MinLengthLogitsProcessor, + ... StoppingCriteriaList, + ... MaxLengthCriteria, + ... ) + >>> tokenizer = AutoTokenizer.from_pretrained("gpt2") + >>> model = AutoModelForCausalLM.from_pretrained("gpt2") + >>> # set pad_token_id to eos_token_id because GPT2 does not have a PAD token + >>> model.config.pad_token_id = model.config.eos_token_id + >>> input_prompt = "It might be possible to" + >>> input_ids = tokenizer(input_prompt, return_tensors="pt").input_ids + >>> stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=20)]) + >>> outputs = model.contrastive_search(input_ids, stopping_criteria=stopping_criteria) + >>> tokenizer.batch_decode(outputs, skip_special_tokens=True) + ["It might be possible to get a better understanding of the nature of the problem, but it's not"] + ```""" + # init values + logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() + stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() + if max_length is not None: + warnings.warn( + "`max_length` is deprecated in this function, use" + " `stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])` instead.", + UserWarning, + ) + stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length) + pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id + eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id + output_scores = output_scores if output_scores is not None else self.config.output_scores + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict_in_generate = ( + return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate + ) + + # init attention / hidden states / scores tuples + scores = () if (return_dict_in_generate and output_scores) else None + decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None + + # keep track of which sequences are already finished + unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1) + + this_peer_finished = False # used by synced_gpus only + + # encode the given prefix and prepare model inputs; encoder-decoder model process the prefix and save the `encoder_outputs` + model_kwargs["use_cache"] = True + model_kwargs["past_key_values"] = None + model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) + output = self(**model_inputs, output_hidden_states=True, output_attentions=True) + + # past_key_values is activated for fast decoding + past_key_values = output.past_key_values + model_inputs["past_key_values"] = past_key_values + + # last decoder hidden states will be used to compute the degeneration penalty (cosine similarity with previous tokens) + if self.config.is_encoder_decoder: + last_hidden_states = output.decoder_hidden_states[-1] + else: + last_hidden_states = output.hidden_states[-1] + # next logit for contrastive search to select top-k candidate tokens + logit_for_next_step = output.logits[:, -1, :] + + while True: + if synced_gpus: + # Under synced_gpus the `forward` call must continue until all gpus complete their sequence. + # The following logic allows an early break if all peers finished generating their sequence + this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device) + # send 0.0 if we finished, 1.0 otherwise + dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM) + # did all peers finish? the reduced sum will be 0.0 then + if this_peer_finished_flag.item() == 0.0: + break + + # contrastive search decoding consists of two steps: (1) candidate tokens recall; (2) candidate re-rank by degeneration penalty + ( + next_tokens, + past_key_values, + last_hidden_states, + logit_for_next_step, + selected_scores, + decoder_hidden_states_one_step, + ) = ContrastiveDecodingOneStepFast( + self, + beam_width=top_k, + penalty_alpha=penalty_alpha, + last_hidden_states=last_hidden_states, + logit_for_next_step=logit_for_next_step, + is_encoder_decoder=self.config.is_encoder_decoder, + **model_inputs, + ) + + if synced_gpus and this_peer_finished: + continue # don't waste resources running the code we don't need + + # Store scores, attentions and hidden_states when required + if return_dict_in_generate: + if output_scores: + scores += (selected_scores,) + + if output_hidden_states: + decoder_hidden_states += (decoder_hidden_states_one_step,) + + # finished sentences should have their next token be a padding token + if eos_token_id is not None: + if pad_token_id is None: + raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.") + next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences) + + # update generated ids, model inputs, and length for next step + input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) + if self.config.is_encoder_decoder: + outputs = Seq2SeqLMOutput( + past_key_values=past_key_values, + ) + else: + outputs = CausalLMOutputWithCrossAttentions( + past_key_values=past_key_values, attentions=model_kwargs["attention_mask"] + ) + model_kwargs = self._update_model_kwargs_for_generation( + outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder + ) + + # if eos_token was found in one sentence, set sentence to finished + if eos_token_id is not None: + unfinished_sequences = unfinished_sequences.mul((next_tokens != eos_token_id).long()) + + # stop when each sentence is finished, or if we exceed the maximum length + if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores): + if not synced_gpus: + break + else: + this_peer_finished = True + + # prepare model inputs + model_kwargs["past_key_values"] = past_key_values + model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) + + if return_dict_in_generate: + if self.config.is_encoder_decoder: + return ContrastiveSearchEncoderDecoderOutput( + sequences=input_ids, + scores=scores, + decoder_hidden_states=decoder_hidden_states, + ) + else: + return ContrastiveSearchDecoderOnlyOutput( + sequences=input_ids, + scores=scores, + hidden_states=decoder_hidden_states, + ) + else: + return input_ids + def greedy_search( self, input_ids: torch.LongTensor, @@ -3446,3 +3694,152 @@ def top_k_top_p_filtering( ) return logits + + +# ========== utils for contrastive search decoding method ========= # +def ranking_fast( + context_hidden: torch.FloatTensor, + next_hidden: torch.FloatTensor, + next_top_k_probs: torch.FloatTensor, + alpha: float, + beam_width: int, +) -> Tuple[torch.FloatTensor]: + """ + context_hidden: bsz*beam x seqlen x embed_dim next_hidden: bsz*beam x 1 x embed_dim next_top_k_probs: bsz x beam + """ + _, context_len, embed_dim = context_hidden.size() + norm_context_hidden = context_hidden / context_hidden.norm(dim=2, keepdim=True) + norm_next_hidden = next_hidden / next_hidden.norm(dim=2, keepdim=True) + cosine_matrix = torch.matmul(norm_context_hidden, norm_next_hidden.transpose(1, 2)).squeeze(-1) # [B*K, S] + scores, _ = torch.max(cosine_matrix, dim=-1) # [B*K] + next_top_k_probs = next_top_k_probs.view(-1) # [B*K] + scores = (1.0 - alpha) * next_top_k_probs - alpha * scores + scores = torch.stack(torch.split(scores, beam_width)) # [B, K] + selected_scores, selected_idx = scores.max(dim=-1) # [B] + return selected_scores, selected_idx + + +def ContrastiveDecodingOneStepFast( + model, + beam_width: int = 1, + penalty_alpha: float = 0.0, + past_key_values: Tuple[Tuple[torch.FloatTensor]] = None, + last_hidden_states: torch.FloatTensor = None, + logit_for_next_step: torch.FloatTensor = None, + is_encoder_decoder: bool = False, + **model_inputs, +) -> Tuple: + """ + contrastive search first selects top-k candidates by the logit scores; then these candidate tokens are fed into the + language models to compute the degeneration penalty, which will be used to re-rank these candidate tokens. + """ + bsz, seqlen, embed_dim = last_hidden_states.size() + next_probs = nn.functional.softmax(logit_for_next_step, dim=-1) + _, top_k_ids = torch.topk(logit_for_next_step, dim=-1, k=beam_width) + top_k_probs = torch.gather(next_probs, dim=1, index=top_k_ids) + past_key_values = enlarge_past_key_values(past_key_values, beam_width) + + # build next attention mask + attention_mask = model_inputs["attention_mask"] # [B, S] + # decoder-only model need the full attention mask, not only the mask for the last token + if is_encoder_decoder is False: + attention_mask = torch.cat([attention_mask, attention_mask.new_ones((bsz, 1))], dim=-1) + attention_mask = attention_mask.unsqueeze(1).expand(-1, beam_width, -1).reshape(-1, attention_mask.size(-1)) + + # encoder-decoder model also contains the `encoder_outputs` + if is_encoder_decoder and "encoder_outputs" in model_inputs: + encoder_outputs = model_inputs["encoder_outputs"] + else: + encoder_outputs = None + next_model_inputs = model.prepare_inputs_for_generation( + top_k_ids.view(-1, 1), + past=past_key_values, + attention_mask=attention_mask, + use_cache=True, + encoder_outputs=encoder_outputs, + ) + # compute the candidate tokens by the language model and collects their hidden_states + output = model(output_hidden_states=True, **next_model_inputs) + past_key_values = output.past_key_values + logits = output.logits[:, -1, :] + # name is different for encoder-decoder and decoder-only models + if is_encoder_decoder: + next_hidden = output.decoder_hidden_states[-1] + full_hidden_states = output.decoder_hidden_states + else: + next_hidden = output.hidden_states[-1] + full_hidden_states = output.hidden_states + context_hidden = ( + last_hidden_states.unsqueeze(1).expand(-1, beam_width, -1, -1).reshape(bsz * beam_width, seqlen, embed_dim) + ) + + # compute the degeneratin penalty and re-rank the candidates based on the degeneration penalty and the model confidence + # the scores and index of the selected tokens are returned + selected_scores, selected_idx = ranking_fast( + context_hidden, + next_hidden, + top_k_probs, + penalty_alpha, + beam_width, + ) + # prepare for the next step: (1) next token_id; (2) past_key_values; (3) last_hidden_states for computing the degeneration penalty; (4) logits for selecting next top-k candidates; (5) selected tokens scores (model confidence minus degeneration penalty); (6) decoder hidden_states + next_id = top_k_ids[range(len(top_k_ids)), selected_idx].unsqueeze(-1) + next_hidden = torch.stack(torch.split(next_hidden.squeeze(dim=1), beam_width)) + next_hidden = next_hidden[range(bsz), selected_idx, :] + last_hidden_states = torch.cat([last_hidden_states, next_hidden.unsqueeze(1)], dim=1) + + decoder_hidden_states = [] + for layer in full_hidden_states: + layer = torch.stack(torch.split(layer.squeeze(dim=1), beam_width)) + layer = layer[range(bsz), selected_idx, :] + decoder_hidden_states.append(layer) + + past_key_values = select_past_key_values(past_key_values, beam_width, selected_idx) + logits = torch.stack(torch.split(logits, beam_width))[range(bsz), selected_idx, :] + return next_id.squeeze(dim=-1), past_key_values, last_hidden_states, logits, selected_scores, decoder_hidden_states + + +def enlarge_past_key_values( + past_key_values: Tuple[Tuple[torch.FloatTensor]], beam_width: int +) -> Tuple[Tuple[torch.FloatTensor]]: + """ + Copy and extend the past_key_values for the next step re-rank each item in `past_key_values` is the 4-dimension + matrix, whose shapre is [batch_size, num_head, seq_len, embed_dim] Suppose the size of the next token candidate + size is K, we need to obtain the new `past_key_values`, whose shape is [batch_size*K, num_head, seq_len, embed_dim] + """ + # from [B, num_head, seq_len, esz] to [B*K, num_head, seq_len, esz] + new_key_values = [] + for layer in past_key_values: + items = [] + # item is either the key or the value matrix + for item in layer: + bsz, num_head, seq_len, esz = item.size() + item = ( + item.unsqueeze(1).expand(-1, beam_width, -1, -1, -1).reshape(bsz * beam_width, num_head, seq_len, esz) + ) # [bsz*beam, num_head, seq_len, esz] + items.append(item) + new_key_values.append(items) + return new_key_values + + +def select_past_key_values( + past_key_values: Tuple[Tuple[torch.FloatTensor]], beam_width: int, selected_idx: torch.FloatTensor +) -> Tuple[Tuple[torch.FloatTensor]]: + """ + Extract the `past_key_value` for the selected tokens, each item in `past_key_value` is the 4-dimension matrix, + whose shape is [batch_size*K, num_head, seq_len, embed_dim], where K is the number of the candidate tokens. We aim + to obtain the `past_key_value` of the selected next token, whose shape is [batch_size, num_head, seq_len, + embed_dim] + """ + new_key_values = [] + for layer in past_key_values: + items = [] + # item is either the key or the value matrix + for item in layer: + bsz_and_beam, num_head, seq_len, esz = item.size() + bsz = int(bsz_and_beam // beam_width) + item = torch.stack(torch.split(item, beam_width, dim=0)) # [B, K, num_head, seq_len, esz] + item = item[range(bsz), selected_idx, :, :, :] # [B, num_head, seq_len, esz] + items.append(item) + new_key_values.append(items) + return new_key_values From 3e71819561f189636a1903368efb0fbb9d86826f Mon Sep 17 00:00:00 2001 From: gmftbyGMFTBY <18811371908@163.com> Date: Mon, 10 Oct 2022 23:33:51 +0800 Subject: [PATCH 02/17] add: testing scripts for contrastive search under examples/text-generation --- .../run_generation_contrastive_search.py | 278 ++++++++++++++++++ src/transformers/generation_utils.py | 80 ++++- 2 files changed, 352 insertions(+), 6 deletions(-) create mode 100755 examples/pytorch/text-generation/run_generation_contrastive_search.py diff --git a/examples/pytorch/text-generation/run_generation_contrastive_search.py b/examples/pytorch/text-generation/run_generation_contrastive_search.py new file mode 100755 index 0000000000000..3a9f7539df906 --- /dev/null +++ b/examples/pytorch/text-generation/run_generation_contrastive_search.py @@ -0,0 +1,278 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Conditional text generation with the auto-regressive models of the library (GPT/GPT-2/CTRL/Transformer-XL/XLNet) +""" + + +import argparse +import logging + +import numpy as np +import torch + +from transformers import ( + CTRLLMHeadModel, + CTRLTokenizer, + GPT2LMHeadModel, + GPT2Tokenizer, + OpenAIGPTLMHeadModel, + OpenAIGPTTokenizer, + TransfoXLLMHeadModel, + TransfoXLTokenizer, + XLMTokenizer, + XLMWithLMHeadModel, + XLNetLMHeadModel, + XLNetTokenizer, +) + + +logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, +) +logger = logging.getLogger(__name__) + +MAX_LENGTH = int(10000) # Hardcoded max length to avoid infinite loop + +MODEL_CLASSES = { + "gpt2": (GPT2LMHeadModel, GPT2Tokenizer), + "ctrl": (CTRLLMHeadModel, CTRLTokenizer), + "openai-gpt": (OpenAIGPTLMHeadModel, OpenAIGPTTokenizer), + "xlnet": (XLNetLMHeadModel, XLNetTokenizer), + "transfo-xl": (TransfoXLLMHeadModel, TransfoXLTokenizer), + "xlm": (XLMWithLMHeadModel, XLMTokenizer), +} + +# Padding text to help Transformer-XL and XLNet with short prompts as proposed by Aman Rusia +# in https://github.com/rusiaaman/XLNet-gen#methodology +# and https://medium.com/@amanrusia/xlnet-speaks-comparison-to-gpt-2-ea1a4e9ba39e +PREFIX = """DeepMind Company is""" + + +def set_seed(args): + np.random.seed(args.seed) + torch.manual_seed(args.seed) + if args.n_gpu > 0: + torch.cuda.manual_seed_all(args.seed) + + +# +# Functions to prepare models' input +# + + +def prepare_ctrl_input(args, _, tokenizer, prompt_text): + if args.temperature > 0.7: + logger.info("CTRL typically works better with lower temperatures (and lower top_k).") + + encoded_prompt = tokenizer.encode(prompt_text, add_special_tokens=False) + if not any(encoded_prompt[0] == x for x in tokenizer.control_codes.values()): + logger.info("WARNING! You are not starting your generation from a control code so you won't get good results") + return prompt_text + + +def prepare_xlm_input(args, model, tokenizer, prompt_text): + # kwargs = {"language": None, "mask_token_id": None} + + # Set the language + use_lang_emb = hasattr(model.config, "use_lang_emb") and model.config.use_lang_emb + if hasattr(model.config, "lang2id") and use_lang_emb: + available_languages = model.config.lang2id.keys() + if args.xlm_language in available_languages: + language = args.xlm_language + else: + language = None + while language not in available_languages: + language = input("Using XLM. Select language in " + str(list(available_languages)) + " >>> ") + + model.config.lang_id = model.config.lang2id[language] + # kwargs["language"] = tokenizer.lang2id[language] + + # TODO fix mask_token_id setup when configurations will be synchronized between models and tokenizers + # XLM masked-language modeling (MLM) models need masked token + # is_xlm_mlm = "mlm" in args.model_name_or_path + # if is_xlm_mlm: + # kwargs["mask_token_id"] = tokenizer.mask_token_id + + return prompt_text + + +def prepare_xlnet_input(args, _, tokenizer, prompt_text): + prefix = args.prefix if args.prefix else args.padding_text if args.padding_text else PREFIX + prompt_text = prefix + prompt_text + return prompt_text + + +def prepare_transfoxl_input(args, _, tokenizer, prompt_text): + prefix = args.prefix if args.prefix else args.padding_text if args.padding_text else PREFIX + prompt_text = prefix + prompt_text + return prompt_text + + +PREPROCESSING_FUNCTIONS = { + "ctrl": prepare_ctrl_input, + "xlm": prepare_xlm_input, + "xlnet": prepare_xlnet_input, + "transfo-xl": prepare_transfoxl_input, +} + + +def adjust_length_to_model(length, max_sequence_length): + if length < 0 and max_sequence_length > 0: + length = max_sequence_length + elif 0 < max_sequence_length < length: + length = max_sequence_length # No generation bigger than model size + elif length < 0: + length = MAX_LENGTH # avoid infinite loop + return length + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--model_type", + default=None, + type=str, + required=True, + help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()), + ) + parser.add_argument( + "--model_name_or_path", + default=None, + type=str, + required=True, + help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(MODEL_CLASSES.keys()), + ) + + parser.add_argument("--prompt", type=str, default="") + parser.add_argument("--length", type=int, default=20) + parser.add_argument("--stop_token", type=str, default=None, help="Token at which text generation is stopped") + + parser.add_argument( + "--temperature", + type=float, + default=1.0, + help="temperature of 1.0 has no effect, lower tend toward greedy sampling", + ) + parser.add_argument( + "--repetition_penalty", type=float, default=1.0, help="primarily useful for CTRL model; in that case, use 1.2" + ) + parser.add_argument("--k", type=int, default=0) + parser.add_argument("--p", type=float, default=0.9) + + parser.add_argument("--prefix", type=str, default="", help="Text added prior to input.") + parser.add_argument("--padding_text", type=str, default="", help="Deprecated, the use of `--prefix` is preferred.") + parser.add_argument("--xlm_language", type=str, default="", help="Optional language when used with the XLM model.") + + parser.add_argument("--seed", type=int, default=42, help="random seed for initialization") + parser.add_argument("--no_cuda", action="store_true", help="Avoid using CUDA when available") + parser.add_argument("--num_return_sequences", type=int, default=1, help="The number of samples to generate.") + parser.add_argument( + "--fp16", + action="store_true", + help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit", + ) + args = parser.parse_args() + + args.device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu") + args.n_gpu = 0 if args.no_cuda else torch.cuda.device_count() + + logger.warning(f"device: {args.device}, n_gpu: {args.n_gpu}, 16-bits training: {args.fp16}") + + set_seed(args) + + # Initialize the model and tokenizer + try: + args.model_type = args.model_type.lower() + model_class, tokenizer_class = MODEL_CLASSES[args.model_type] + except KeyError: + raise KeyError("the model {} you specified is not supported. You are welcome to add it and open a PR :)") + + tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path) + model = model_class.from_pretrained(args.model_name_or_path) + model.to(args.device) + + if args.fp16: + model.half() + + args.length = adjust_length_to_model(args.length, max_sequence_length=model.config.max_position_embeddings) + logger.info(args) + + prompt_text = args.prompt if args.prompt else input("Model prompt >>> ") + + # Different models need different input formatting and/or extra arguments + requires_preprocessing = args.model_type in PREPROCESSING_FUNCTIONS.keys() + if requires_preprocessing: + prepare_input = PREPROCESSING_FUNCTIONS.get(args.model_type) + preprocessed_prompt_text = prepare_input(args, model, tokenizer, prompt_text) + + if model.__class__.__name__ in ["TransfoXLLMHeadModel"]: + tokenizer_kwargs = {"add_space_before_punct_symbol": True} + else: + tokenizer_kwargs = {} + + encoded_prompt = tokenizer.encode( + preprocessed_prompt_text, add_special_tokens=False, return_tensors="pt", **tokenizer_kwargs + ) + else: + prefix = args.prefix if args.prefix else args.padding_text + encoded_prompt = tokenizer.encode(prefix + prompt_text, add_special_tokens=False, return_tensors="pt") + encoded_prompt = encoded_prompt.to(args.device) + + if encoded_prompt.size()[-1] == 0: + input_ids = None + else: + input_ids = encoded_prompt + + output_sequences = model.generate( + input_ids=input_ids, + max_length=args.length + len(encoded_prompt[0]), + num_return_sequences=args.num_return_sequences, + penalty_alpha=0.6, + top_k=4, + ) + + # Remove the batch dimension when returning multiple sequences + if len(output_sequences.shape) > 2: + output_sequences.squeeze_() + + generated_sequences = [] + + for generated_sequence_idx, generated_sequence in enumerate(output_sequences): + print(f"=== GENERATED SEQUENCE {generated_sequence_idx + 1} ===") + generated_sequence = generated_sequence.tolist() + + # Decode text + text = tokenizer.decode(generated_sequence, clean_up_tokenization_spaces=True) + + # Remove all text after the stop token + text = text[: text.find(args.stop_token) if args.stop_token else None] + + # Add the prompt at the beginning of the sequence. Remove the excess text that was used for pre-processing + total_sequence = ( + prompt_text + text[len(tokenizer.decode(encoded_prompt[0], clean_up_tokenization_spaces=True)) :] + ) + + generated_sequences.append(total_sequence) + print(total_sequence) + + return generated_sequences + + +if __name__ == "__main__": + main() diff --git a/src/transformers/generation_utils.py b/src/transformers/generation_utils.py index cf00c6b0d092a..ae2dbf59464d1 100644 --- a/src/transformers/generation_utils.py +++ b/src/transformers/generation_utils.py @@ -54,6 +54,7 @@ StoppingCriteriaList, validate_stopping_criteria, ) +from .modeling_outputs import CausalLMOutputWithCrossAttentions, Seq2SeqLMOutput from .models.auto import ( MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING, MODEL_FOR_CAUSAL_LM_MAPPING, @@ -96,6 +97,50 @@ class GreedySearchDecoderOnlyOutput(ModelOutput): hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None +@dataclass +class ContrastiveSearchEncoderDecoderOutput(ModelOutput): + """ + Args: + Base class for outputs of decoder-only generation models using contrastive search. + sequences (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter + if all batches finished early due to the `eos_token_id`. + scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`): + Processed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax) + at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for + each generated token), with each tensor of shape `(batch_size, config.vocab_size)`. + decoder_hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of + `torch.FloatTensor` of shape `(batch_size, generated_length, hidden_size)`. + """ + + sequences: torch.LongTensor = None + scores: Optional[Tuple[torch.FloatTensor]] = None + decoder_hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + + +@dataclass +class ContrastiveSearchDecoderOnlyOutput(ModelOutput): + """ + Args: + Base class for outputs of decoder-only generation models using contrastive search. + sequences (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter + if all batches finished early due to the `eos_token_id`. + scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`): + Processed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax) + at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for + each generated token), with each tensor of shape `(batch_size, config.vocab_size)`. + hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of + `torch.FloatTensor` of shape `(batch_size, generated_length, hidden_size)`. + """ + + sequences: torch.LongTensor = None + scores: Optional[Tuple[torch.FloatTensor]] = None + hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + + @dataclass class GreedySearchEncoderDecoderOutput(ModelOutput): """ @@ -921,6 +966,7 @@ def generate( early_stopping: Optional[bool] = None, num_beams: Optional[int] = None, temperature: Optional[float] = None, + penalty_alpha: Optional[float] = None, top_k: Optional[int] = None, top_p: Optional[float] = None, typical_p: Optional[float] = None, @@ -1334,18 +1380,39 @@ def generate( ) is_greedy_gen_mode = ( - (num_beams == 1) and (num_beam_groups == 1) and do_sample is False and not is_constraint_gen_mode and not is_contrastive_search_gen_mode + (num_beams == 1) + and (num_beam_groups == 1) + and do_sample is False + and not is_constraint_gen_mode + and not is_contrastive_search_gen_mode ) is_sample_gen_mode = ( - (num_beams == 1) and (num_beam_groups == 1) and do_sample is True and not is_constraint_gen_mode and not is_contrastive_search_gen_mode + (num_beams == 1) + and (num_beam_groups == 1) + and do_sample is True + and not is_constraint_gen_mode + and not is_contrastive_search_gen_mode ) is_beam_gen_mode = ( - (num_beams > 1) and (num_beam_groups == 1) and do_sample is False and not is_constraint_gen_mode and not is_contrastive_search_gen_mode + (num_beams > 1) + and (num_beam_groups == 1) + and do_sample is False + and not is_constraint_gen_mode + and not is_contrastive_search_gen_mode ) is_beam_sample_gen_mode = ( - (num_beams > 1) and (num_beam_groups == 1) and do_sample is True and not is_constraint_gen_mode and not is_contrastive_search_gen_mode + (num_beams > 1) + and (num_beam_groups == 1) + and do_sample is True + and not is_constraint_gen_mode + and not is_contrastive_search_gen_mode + ) + is_group_beam_gen_mode = ( + (num_beams > 1) + and (num_beam_groups > 1) + and not is_constraint_gen_mode + and not is_contrastive_search_gen_mode ) - is_group_beam_gen_mode = (num_beams > 1) and (num_beam_groups > 1) and not is_constraint_gen_mode and not is_contrastive_search_gen_mode if num_beam_groups > num_beams: raise ValueError("`num_beam_groups` has to be smaller or equal to `num_beams`") @@ -1679,9 +1746,9 @@ def contrastive_search( **model_kwargs, ) -> Union[GreedySearchOutput, torch.LongTensor]: r""" + Parameters: Generates sequences of token ids for models with a language modeling head using **contrastive search** and can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models. - Parameters: input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): The sequence used as a prompt for the generation. top_k (`int`, *optional*, defaults to 1): @@ -1732,6 +1799,7 @@ def contrastive_search( ... StoppingCriteriaList, ... MaxLengthCriteria, ... ) + >>> tokenizer = AutoTokenizer.from_pretrained("gpt2") >>> model = AutoModelForCausalLM.from_pretrained("gpt2") >>> # set pad_token_id to eos_token_id because GPT2 does not have a PAD token From 41e37a56854cc5d3b0470932c511c11950dafc99 Mon Sep 17 00:00:00 2001 From: gmftbyGMFTBY <18811371908@163.com> Date: Tue, 11 Oct 2022 10:41:00 +0800 Subject: [PATCH 03/17] update the quality of codes --- src/transformers/generation_utils.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/src/transformers/generation_utils.py b/src/transformers/generation_utils.py index ae2dbf59464d1..84158d620a0d0 100644 --- a/src/transformers/generation_utils.py +++ b/src/transformers/generation_utils.py @@ -105,11 +105,13 @@ class ContrastiveSearchEncoderDecoderOutput(ModelOutput): sequences (`torch.LongTensor` of shape `(batch_size, sequence_length)`): The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter if all batches finished early due to the `eos_token_id`. - scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`): + scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True` is passed or when + `config.output_scores=True`): Processed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax) at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for each generated token), with each tensor of shape `(batch_size, config.vocab_size)`. - decoder_hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + decoder_hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` + is passed or when `config.output_hidden_states=True`): Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of `torch.FloatTensor` of shape `(batch_size, generated_length, hidden_size)`. """ @@ -127,11 +129,13 @@ class ContrastiveSearchDecoderOnlyOutput(ModelOutput): sequences (`torch.LongTensor` of shape `(batch_size, sequence_length)`): The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter if all batches finished early due to the `eos_token_id`. - scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`): + scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True` is passed or when + `config.output_scores=True`): Processed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax) at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for each generated token), with each tensor of shape `(batch_size, config.vocab_size)`. - hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` is + passed or when `config.output_hidden_states=True`): Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of `torch.FloatTensor` of shape `(batch_size, generated_length, hidden_size)`. """ From 32e2a3070709e6d3620c82f9bbab9feeedfe1817 Mon Sep 17 00:00:00 2001 From: gmftbyGMFTBY <18811371908@163.com> Date: Wed, 12 Oct 2022 14:03:41 +0800 Subject: [PATCH 04/17] revise the docstring; make the generation_contrastive_search.py scripts; --- .../run_generation_contrastive_search.py | 8 - .../generation_contrastive_search.py | 178 ++++++++++++++++++ src/transformers/generation_utils.py | 159 +--------------- 3 files changed, 187 insertions(+), 158 deletions(-) create mode 100644 src/transformers/generation_contrastive_search.py diff --git a/examples/pytorch/text-generation/run_generation_contrastive_search.py b/examples/pytorch/text-generation/run_generation_contrastive_search.py index 3a9f7539df906..9595b920c6af2 100755 --- a/examples/pytorch/text-generation/run_generation_contrastive_search.py +++ b/examples/pytorch/text-generation/run_generation_contrastive_search.py @@ -101,14 +101,6 @@ def prepare_xlm_input(args, model, tokenizer, prompt_text): language = input("Using XLM. Select language in " + str(list(available_languages)) + " >>> ") model.config.lang_id = model.config.lang2id[language] - # kwargs["language"] = tokenizer.lang2id[language] - - # TODO fix mask_token_id setup when configurations will be synchronized between models and tokenizers - # XLM masked-language modeling (MLM) models need masked token - # is_xlm_mlm = "mlm" in args.model_name_or_path - # if is_xlm_mlm: - # kwargs["mask_token_id"] = tokenizer.mask_token_id - return prompt_text diff --git a/src/transformers/generation_contrastive_search.py b/src/transformers/generation_contrastive_search.py new file mode 100644 index 0000000000000..286249a6eb85d --- /dev/null +++ b/src/transformers/generation_contrastive_search.py @@ -0,0 +1,178 @@ +# coding=utf-8 +# Copyright 2020 The Google AI Language Team Authors, Facebook AI Research authors and The HuggingFace Inc. team. +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Tuple + +import torch +from torch import nn + +from .utils import logging + + +logger = logging.get_logger(__name__) + + +""" +This file contains the utils functions for the contrastive search, which will be called in `generation_utils` +""" + + +def ranking_fast( + context_hidden: torch.FloatTensor, + next_hidden: torch.FloatTensor, + next_top_k_probs: torch.FloatTensor, + alpha: float, + beam_width: int, +) -> Tuple[torch.FloatTensor]: + """ + context_hidden: bsz*beam x seqlen x embed_dim next_hidden: bsz*beam x 1 x embed_dim next_top_k_probs: bsz x beam + """ + _, context_len, embed_dim = context_hidden.size() + norm_context_hidden = context_hidden / context_hidden.norm(dim=2, keepdim=True) + norm_next_hidden = next_hidden / next_hidden.norm(dim=2, keepdim=True) + cosine_matrix = torch.matmul(norm_context_hidden, norm_next_hidden.transpose(1, 2)).squeeze(-1) # [B*K, S] + scores, _ = torch.max(cosine_matrix, dim=-1) # [B*K] + next_top_k_probs = next_top_k_probs.view(-1) # [B*K] + scores = (1.0 - alpha) * next_top_k_probs - alpha * scores + scores = torch.stack(torch.split(scores, beam_width)) # [B, K] + selected_scores, selected_idx = scores.max(dim=-1) # [B] + return selected_scores, selected_idx + + +def ContrastiveDecodingOneStepFast( + model, + beam_width: int = 1, + penalty_alpha: float = 0.0, + past_key_values: Tuple[Tuple[torch.FloatTensor]] = None, + last_hidden_states: torch.FloatTensor = None, + logit_for_next_step: torch.FloatTensor = None, + is_encoder_decoder: bool = False, + **model_inputs, +) -> Tuple: + """ + contrastive search first selects top-k candidates by the logit scores; then these candidate tokens are fed into the + language models to compute the degeneration penalty, which will be used to re-rank these candidate tokens. + """ + bsz, seqlen, embed_dim = last_hidden_states.size() + next_probs = nn.functional.softmax(logit_for_next_step, dim=-1) + _, top_k_ids = torch.topk(logit_for_next_step, dim=-1, k=beam_width) + top_k_probs = torch.gather(next_probs, dim=1, index=top_k_ids) + past_key_values = enlarge_past_key_values(past_key_values, beam_width) + + # build next attention mask + attention_mask = model_inputs["attention_mask"] # [B, S] + # decoder-only model need the full attention mask, not only the mask for the last token + if is_encoder_decoder is False: + attention_mask = torch.cat([attention_mask, attention_mask.new_ones((bsz, 1))], dim=-1) + attention_mask = attention_mask.unsqueeze(1).expand(-1, beam_width, -1).reshape(-1, attention_mask.size(-1)) + + # encoder-decoder model also contains the `encoder_outputs` + if is_encoder_decoder and "encoder_outputs" in model_inputs: + encoder_outputs = model_inputs["encoder_outputs"] + else: + encoder_outputs = None + next_model_inputs = model.prepare_inputs_for_generation( + top_k_ids.view(-1, 1), + past=past_key_values, + attention_mask=attention_mask, + use_cache=True, + encoder_outputs=encoder_outputs, + ) + # compute the candidate tokens by the language model and collects their hidden_states + output = model(output_hidden_states=True, **next_model_inputs) + past_key_values = output.past_key_values + logits = output.logits[:, -1, :] + # name is different for encoder-decoder and decoder-only models + if is_encoder_decoder: + next_hidden = output.decoder_hidden_states[-1] + full_hidden_states = output.decoder_hidden_states + else: + next_hidden = output.hidden_states[-1] + full_hidden_states = output.hidden_states + context_hidden = ( + last_hidden_states.unsqueeze(1).expand(-1, beam_width, -1, -1).reshape(bsz * beam_width, seqlen, embed_dim) + ) + + # compute the degeneratin penalty and re-rank the candidates based on the degeneration penalty and the model confidence + # the scores and index of the selected tokens are returned + selected_scores, selected_idx = ranking_fast( + context_hidden, + next_hidden, + top_k_probs, + penalty_alpha, + beam_width, + ) + # prepare for the next step: (1) next token_id; (2) past_key_values; (3) last_hidden_states for computing the degeneration penalty; (4) logits for selecting next top-k candidates; (5) selected tokens scores (model confidence minus degeneration penalty); (6) decoder hidden_states + next_id = top_k_ids[range(len(top_k_ids)), selected_idx].unsqueeze(-1) + next_hidden = torch.stack(torch.split(next_hidden.squeeze(dim=1), beam_width)) + next_hidden = next_hidden[range(bsz), selected_idx, :] + last_hidden_states = torch.cat([last_hidden_states, next_hidden.unsqueeze(1)], dim=1) + + decoder_hidden_states = [] + for layer in full_hidden_states: + layer = torch.stack(torch.split(layer.squeeze(dim=1), beam_width)) + layer = layer[range(bsz), selected_idx, :] + decoder_hidden_states.append(layer) + + past_key_values = select_past_key_values(past_key_values, beam_width, selected_idx) + logits = torch.stack(torch.split(logits, beam_width))[range(bsz), selected_idx, :] + return next_id.squeeze(dim=-1), past_key_values, last_hidden_states, logits, selected_scores, decoder_hidden_states + + +def enlarge_past_key_values( + past_key_values: Tuple[Tuple[torch.FloatTensor]], beam_width: int +) -> Tuple[Tuple[torch.FloatTensor]]: + """ + Copy and extend the past_key_values for the next step re-rank each item in `past_key_values` is the 4-dimension + matrix, whose shapre is [batch_size, num_head, seq_len, embed_dim] Suppose the size of the next token candidate + size is K, we need to obtain the new `past_key_values`, whose shape is [batch_size*K, num_head, seq_len, embed_dim] + """ + # from [B, num_head, seq_len, esz] to [B*K, num_head, seq_len, esz] + new_key_values = [] + for layer in past_key_values: + items = [] + # item is either the key or the value matrix + for item in layer: + bsz, num_head, seq_len, esz = item.size() + item = ( + item.unsqueeze(1).expand(-1, beam_width, -1, -1, -1).reshape(bsz * beam_width, num_head, seq_len, esz) + ) # [bsz*beam, num_head, seq_len, esz] + items.append(item) + new_key_values.append(items) + return new_key_values + + +def select_past_key_values( + past_key_values: Tuple[Tuple[torch.FloatTensor]], beam_width: int, selected_idx: torch.FloatTensor +) -> Tuple[Tuple[torch.FloatTensor]]: + """ + Extract the `past_key_value` for the selected tokens, each item in `past_key_value` is the 4-dimension matrix, + whose shape is [batch_size*K, num_head, seq_len, embed_dim], where K is the number of the candidate tokens. We aim + to obtain the `past_key_value` of the selected next token, whose shape is [batch_size, num_head, seq_len, + embed_dim] + """ + new_key_values = [] + for layer in past_key_values: + items = [] + # item is either the key or the value matrix + for item in layer: + bsz_and_beam, num_head, seq_len, esz = item.size() + bsz = int(bsz_and_beam // beam_width) + item = torch.stack(torch.split(item, beam_width, dim=0)) # [B, K, num_head, seq_len, esz] + item = item[range(bsz), selected_idx, :, :, :] # [B, num_head, seq_len, esz] + items.append(item) + new_key_values.append(items) + return new_key_values diff --git a/src/transformers/generation_utils.py b/src/transformers/generation_utils.py index 4fd9fa99cceef..c0132743c29b8 100644 --- a/src/transformers/generation_utils.py +++ b/src/transformers/generation_utils.py @@ -25,6 +25,7 @@ from .generation_beam_constraints import Constraint, DisjunctiveConstraint, PhrasalConstraint from .generation_beam_search import BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer +from .generation_contrastive_search import ContrastiveDecodingOneStepFast from .generation_logits_process import ( EncoderNoRepeatNGramLogitsProcessor, ExponentialDecayLengthPenalty, @@ -1016,6 +1017,8 @@ def generate( - *greedy decoding* by calling [`~generation_utils.GenerationMixin.greedy_search`] if `num_beams=1` and `do_sample=False`. + - *contrastive search* by calling [`~generation_utils.GenerationMixin.contrastive_search`] if + `penalty_alpha>0.` and `top_k>1` - *multinomial sampling* by calling [`~generation_utils.GenerationMixin.sample`] if `num_beams=1` and `do_sample=True`. - *beam-search decoding* by calling [`~generation_utils.GenerationMixin.beam_search`] if `num_beams>1` and @@ -1061,6 +1064,8 @@ def generate( Number of beams for beam search. 1 means no beam search. temperature (`float`, *optional*, defaults to `model.config.temperature` or 1.0 if the config does not set any value): The value used to module the next token probabilities. + penalty_alpha (`float`, *optional*, defaults to `model.config.penalty_alpha` or None if the config does not set any value): + The values balance the model confidence and the degeneration penalty in contrastive search decoding. top_k (`int`, *optional*, defaults to `model.config.top_k` or 50 if the config does not set any value): The number of highest probability vocabulary tokens to keep for top-k-filtering. top_p (`float`, *optional*, defaults to `model.config.top_p` or 1.0 if the config does not set any value): @@ -1761,9 +1766,10 @@ def contrastive_search( **model_kwargs, ) -> Union[GreedySearchOutput, torch.LongTensor]: r""" - Parameters: Generates sequences of token ids for models with a language modeling head using **contrastive search** and can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models. + + Parameters: input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): The sequence used as a prompt for the generation. top_k (`int`, *optional*, defaults to 1): @@ -1798,12 +1804,14 @@ def contrastive_search( model_kwargs: Additional model specific keyword arguments will be forwarded to the `forward` function of the model. If model is an encoder-decoder model the kwargs should include `encoder_outputs`. + Return: [`~generation_utils.GreedySearchDecoderOnlyOutput`], [`~generation_utils.GreedySearchEncoderDecoderOutput`] or `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a [`~generation_utils.GreedySearchDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and `return_dict_in_generate=True` or a [`~generation_utils.GreedySearchEncoderDecoderOutput`] if `model.config.is_encoder_decoder=True`. + Examples: ```python >>> from transformers import ( @@ -3776,152 +3784,3 @@ def top_k_top_p_filtering( ) return logits - - -# ========== utils for contrastive search decoding method ========= # -def ranking_fast( - context_hidden: torch.FloatTensor, - next_hidden: torch.FloatTensor, - next_top_k_probs: torch.FloatTensor, - alpha: float, - beam_width: int, -) -> Tuple[torch.FloatTensor]: - """ - context_hidden: bsz*beam x seqlen x embed_dim next_hidden: bsz*beam x 1 x embed_dim next_top_k_probs: bsz x beam - """ - _, context_len, embed_dim = context_hidden.size() - norm_context_hidden = context_hidden / context_hidden.norm(dim=2, keepdim=True) - norm_next_hidden = next_hidden / next_hidden.norm(dim=2, keepdim=True) - cosine_matrix = torch.matmul(norm_context_hidden, norm_next_hidden.transpose(1, 2)).squeeze(-1) # [B*K, S] - scores, _ = torch.max(cosine_matrix, dim=-1) # [B*K] - next_top_k_probs = next_top_k_probs.view(-1) # [B*K] - scores = (1.0 - alpha) * next_top_k_probs - alpha * scores - scores = torch.stack(torch.split(scores, beam_width)) # [B, K] - selected_scores, selected_idx = scores.max(dim=-1) # [B] - return selected_scores, selected_idx - - -def ContrastiveDecodingOneStepFast( - model, - beam_width: int = 1, - penalty_alpha: float = 0.0, - past_key_values: Tuple[Tuple[torch.FloatTensor]] = None, - last_hidden_states: torch.FloatTensor = None, - logit_for_next_step: torch.FloatTensor = None, - is_encoder_decoder: bool = False, - **model_inputs, -) -> Tuple: - """ - contrastive search first selects top-k candidates by the logit scores; then these candidate tokens are fed into the - language models to compute the degeneration penalty, which will be used to re-rank these candidate tokens. - """ - bsz, seqlen, embed_dim = last_hidden_states.size() - next_probs = nn.functional.softmax(logit_for_next_step, dim=-1) - _, top_k_ids = torch.topk(logit_for_next_step, dim=-1, k=beam_width) - top_k_probs = torch.gather(next_probs, dim=1, index=top_k_ids) - past_key_values = enlarge_past_key_values(past_key_values, beam_width) - - # build next attention mask - attention_mask = model_inputs["attention_mask"] # [B, S] - # decoder-only model need the full attention mask, not only the mask for the last token - if is_encoder_decoder is False: - attention_mask = torch.cat([attention_mask, attention_mask.new_ones((bsz, 1))], dim=-1) - attention_mask = attention_mask.unsqueeze(1).expand(-1, beam_width, -1).reshape(-1, attention_mask.size(-1)) - - # encoder-decoder model also contains the `encoder_outputs` - if is_encoder_decoder and "encoder_outputs" in model_inputs: - encoder_outputs = model_inputs["encoder_outputs"] - else: - encoder_outputs = None - next_model_inputs = model.prepare_inputs_for_generation( - top_k_ids.view(-1, 1), - past=past_key_values, - attention_mask=attention_mask, - use_cache=True, - encoder_outputs=encoder_outputs, - ) - # compute the candidate tokens by the language model and collects their hidden_states - output = model(output_hidden_states=True, **next_model_inputs) - past_key_values = output.past_key_values - logits = output.logits[:, -1, :] - # name is different for encoder-decoder and decoder-only models - if is_encoder_decoder: - next_hidden = output.decoder_hidden_states[-1] - full_hidden_states = output.decoder_hidden_states - else: - next_hidden = output.hidden_states[-1] - full_hidden_states = output.hidden_states - context_hidden = ( - last_hidden_states.unsqueeze(1).expand(-1, beam_width, -1, -1).reshape(bsz * beam_width, seqlen, embed_dim) - ) - - # compute the degeneratin penalty and re-rank the candidates based on the degeneration penalty and the model confidence - # the scores and index of the selected tokens are returned - selected_scores, selected_idx = ranking_fast( - context_hidden, - next_hidden, - top_k_probs, - penalty_alpha, - beam_width, - ) - # prepare for the next step: (1) next token_id; (2) past_key_values; (3) last_hidden_states for computing the degeneration penalty; (4) logits for selecting next top-k candidates; (5) selected tokens scores (model confidence minus degeneration penalty); (6) decoder hidden_states - next_id = top_k_ids[range(len(top_k_ids)), selected_idx].unsqueeze(-1) - next_hidden = torch.stack(torch.split(next_hidden.squeeze(dim=1), beam_width)) - next_hidden = next_hidden[range(bsz), selected_idx, :] - last_hidden_states = torch.cat([last_hidden_states, next_hidden.unsqueeze(1)], dim=1) - - decoder_hidden_states = [] - for layer in full_hidden_states: - layer = torch.stack(torch.split(layer.squeeze(dim=1), beam_width)) - layer = layer[range(bsz), selected_idx, :] - decoder_hidden_states.append(layer) - - past_key_values = select_past_key_values(past_key_values, beam_width, selected_idx) - logits = torch.stack(torch.split(logits, beam_width))[range(bsz), selected_idx, :] - return next_id.squeeze(dim=-1), past_key_values, last_hidden_states, logits, selected_scores, decoder_hidden_states - - -def enlarge_past_key_values( - past_key_values: Tuple[Tuple[torch.FloatTensor]], beam_width: int -) -> Tuple[Tuple[torch.FloatTensor]]: - """ - Copy and extend the past_key_values for the next step re-rank each item in `past_key_values` is the 4-dimension - matrix, whose shapre is [batch_size, num_head, seq_len, embed_dim] Suppose the size of the next token candidate - size is K, we need to obtain the new `past_key_values`, whose shape is [batch_size*K, num_head, seq_len, embed_dim] - """ - # from [B, num_head, seq_len, esz] to [B*K, num_head, seq_len, esz] - new_key_values = [] - for layer in past_key_values: - items = [] - # item is either the key or the value matrix - for item in layer: - bsz, num_head, seq_len, esz = item.size() - item = ( - item.unsqueeze(1).expand(-1, beam_width, -1, -1, -1).reshape(bsz * beam_width, num_head, seq_len, esz) - ) # [bsz*beam, num_head, seq_len, esz] - items.append(item) - new_key_values.append(items) - return new_key_values - - -def select_past_key_values( - past_key_values: Tuple[Tuple[torch.FloatTensor]], beam_width: int, selected_idx: torch.FloatTensor -) -> Tuple[Tuple[torch.FloatTensor]]: - """ - Extract the `past_key_value` for the selected tokens, each item in `past_key_value` is the 4-dimension matrix, - whose shape is [batch_size*K, num_head, seq_len, embed_dim], where K is the number of the candidate tokens. We aim - to obtain the `past_key_value` of the selected next token, whose shape is [batch_size, num_head, seq_len, - embed_dim] - """ - new_key_values = [] - for layer in past_key_values: - items = [] - # item is either the key or the value matrix - for item in layer: - bsz_and_beam, num_head, seq_len, esz = item.size() - bsz = int(bsz_and_beam // beam_width) - item = torch.stack(torch.split(item, beam_width, dim=0)) # [B, K, num_head, seq_len, esz] - item = item[range(bsz), selected_idx, :, :, :] # [B, num_head, seq_len, esz] - items.append(item) - new_key_values.append(items) - return new_key_values From f3bfd87d3a7a71e2ea106ecd7400d5345dc065b0 Mon Sep 17 00:00:00 2001 From: gmftbyGMFTBY <18811371908@163.com> Date: Wed, 12 Oct 2022 22:35:07 +0800 Subject: [PATCH 05/17] revise the examples/pytorch/text-generation/run_generation_contrastive_search.py to the auto-APIs format --- .../run_generation_contrastive_search.py | 167 ++---------------- 1 file changed, 18 insertions(+), 149 deletions(-) diff --git a/examples/pytorch/text-generation/run_generation_contrastive_search.py b/examples/pytorch/text-generation/run_generation_contrastive_search.py index 9595b920c6af2..2b592f0219a07 100755 --- a/examples/pytorch/text-generation/run_generation_contrastive_search.py +++ b/examples/pytorch/text-generation/run_generation_contrastive_search.py @@ -14,7 +14,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" Conditional text generation with the auto-regressive models of the library (GPT/GPT-2/CTRL/Transformer-XL/XLNet) +""" The examples of running contrastive search on the auto-APIs; + +Running this examples: +CUDA_VISIBLE_DEVICES=0 python run_generation_contrastive_search.py --model_name_or_path=gpt2-large --penalty_alpha=0.6 --k=4 --length=256 """ @@ -24,20 +27,7 @@ import numpy as np import torch -from transformers import ( - CTRLLMHeadModel, - CTRLTokenizer, - GPT2LMHeadModel, - GPT2Tokenizer, - OpenAIGPTLMHeadModel, - OpenAIGPTTokenizer, - TransfoXLLMHeadModel, - TransfoXLTokenizer, - XLMTokenizer, - XLMWithLMHeadModel, - XLNetLMHeadModel, - XLNetTokenizer, -) +from transformers import AutoModelForCausalLM, AutoTokenizer logging.basicConfig( @@ -47,22 +37,6 @@ ) logger = logging.getLogger(__name__) -MAX_LENGTH = int(10000) # Hardcoded max length to avoid infinite loop - -MODEL_CLASSES = { - "gpt2": (GPT2LMHeadModel, GPT2Tokenizer), - "ctrl": (CTRLLMHeadModel, CTRLTokenizer), - "openai-gpt": (OpenAIGPTLMHeadModel, OpenAIGPTTokenizer), - "xlnet": (XLNetLMHeadModel, XLNetTokenizer), - "transfo-xl": (TransfoXLLMHeadModel, TransfoXLTokenizer), - "xlm": (XLMWithLMHeadModel, XLMTokenizer), -} - -# Padding text to help Transformer-XL and XLNet with short prompts as proposed by Aman Rusia -# in https://github.com/rusiaaman/XLNet-gen#methodology -# and https://medium.com/@amanrusia/xlnet-speaks-comparison-to-gpt-2-ea1a4e9ba39e -PREFIX = """DeepMind Company is""" - def set_seed(args): np.random.seed(args.seed) @@ -71,90 +45,17 @@ def set_seed(args): torch.cuda.manual_seed_all(args.seed) -# -# Functions to prepare models' input -# - - -def prepare_ctrl_input(args, _, tokenizer, prompt_text): - if args.temperature > 0.7: - logger.info("CTRL typically works better with lower temperatures (and lower top_k).") - - encoded_prompt = tokenizer.encode(prompt_text, add_special_tokens=False) - if not any(encoded_prompt[0] == x for x in tokenizer.control_codes.values()): - logger.info("WARNING! You are not starting your generation from a control code so you won't get good results") - return prompt_text - - -def prepare_xlm_input(args, model, tokenizer, prompt_text): - # kwargs = {"language": None, "mask_token_id": None} - - # Set the language - use_lang_emb = hasattr(model.config, "use_lang_emb") and model.config.use_lang_emb - if hasattr(model.config, "lang2id") and use_lang_emb: - available_languages = model.config.lang2id.keys() - if args.xlm_language in available_languages: - language = args.xlm_language - else: - language = None - while language not in available_languages: - language = input("Using XLM. Select language in " + str(list(available_languages)) + " >>> ") - - model.config.lang_id = model.config.lang2id[language] - return prompt_text - - -def prepare_xlnet_input(args, _, tokenizer, prompt_text): - prefix = args.prefix if args.prefix else args.padding_text if args.padding_text else PREFIX - prompt_text = prefix + prompt_text - return prompt_text - - -def prepare_transfoxl_input(args, _, tokenizer, prompt_text): - prefix = args.prefix if args.prefix else args.padding_text if args.padding_text else PREFIX - prompt_text = prefix + prompt_text - return prompt_text - - -PREPROCESSING_FUNCTIONS = { - "ctrl": prepare_ctrl_input, - "xlm": prepare_xlm_input, - "xlnet": prepare_xlnet_input, - "transfo-xl": prepare_transfoxl_input, -} - - -def adjust_length_to_model(length, max_sequence_length): - if length < 0 and max_sequence_length > 0: - length = max_sequence_length - elif 0 < max_sequence_length < length: - length = max_sequence_length # No generation bigger than model size - elif length < 0: - length = MAX_LENGTH # avoid infinite loop - return length - - def main(): parser = argparse.ArgumentParser() - parser.add_argument( - "--model_type", - default=None, - type=str, - required=True, - help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()), - ) parser.add_argument( "--model_name_or_path", default=None, type=str, required=True, - help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(MODEL_CLASSES.keys()), ) - parser.add_argument("--prompt", type=str, default="") parser.add_argument("--length", type=int, default=20) parser.add_argument("--stop_token", type=str, default=None, help="Token at which text generation is stopped") - parser.add_argument( "--temperature", type=float, @@ -165,6 +66,7 @@ def main(): "--repetition_penalty", type=float, default=1.0, help="primarily useful for CTRL model; in that case, use 1.2" ) parser.add_argument("--k", type=int, default=0) + parser.add_argument("--penalty_alpha", type=float, default=0.0) parser.add_argument("--p", type=float, default=0.9) parser.add_argument("--prefix", type=str, default="", help="Text added prior to input.") @@ -173,7 +75,6 @@ def main(): parser.add_argument("--seed", type=int, default=42, help="random seed for initialization") parser.add_argument("--no_cuda", action="store_true", help="Avoid using CUDA when available") - parser.add_argument("--num_return_sequences", type=int, default=1, help="The number of samples to generate.") parser.add_argument( "--fp16", action="store_true", @@ -189,75 +90,43 @@ def main(): set_seed(args) # Initialize the model and tokenizer - try: - args.model_type = args.model_type.lower() - model_class, tokenizer_class = MODEL_CLASSES[args.model_type] - except KeyError: - raise KeyError("the model {} you specified is not supported. You are welcome to add it and open a PR :)") + tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path) + model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path) - tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path) - model = model_class.from_pretrained(args.model_name_or_path) + # tokenizer = GPT2Tokenizer.from_pretrained(args.model_name_or_path) + # model = OPTForCausalLM.from_pretrained(args.model_name_or_path) model.to(args.device) if args.fp16: model.half() - args.length = adjust_length_to_model(args.length, max_sequence_length=model.config.max_position_embeddings) logger.info(args) - prompt_text = args.prompt if args.prompt else input("Model prompt >>> ") - # Different models need different input formatting and/or extra arguments - requires_preprocessing = args.model_type in PREPROCESSING_FUNCTIONS.keys() - if requires_preprocessing: - prepare_input = PREPROCESSING_FUNCTIONS.get(args.model_type) - preprocessed_prompt_text = prepare_input(args, model, tokenizer, prompt_text) - - if model.__class__.__name__ in ["TransfoXLLMHeadModel"]: - tokenizer_kwargs = {"add_space_before_punct_symbol": True} - else: - tokenizer_kwargs = {} - - encoded_prompt = tokenizer.encode( - preprocessed_prompt_text, add_special_tokens=False, return_tensors="pt", **tokenizer_kwargs - ) - else: - prefix = args.prefix if args.prefix else args.padding_text - encoded_prompt = tokenizer.encode(prefix + prompt_text, add_special_tokens=False, return_tensors="pt") - encoded_prompt = encoded_prompt.to(args.device) - - if encoded_prompt.size()[-1] == 0: - input_ids = None - else: - input_ids = encoded_prompt + inputs = tokenizer(prompt_text, return_tensors="pt", add_special_tokens=False) + inputs = {key: value.to(args.device) for key, value in inputs.items()} output_sequences = model.generate( - input_ids=input_ids, - max_length=args.length + len(encoded_prompt[0]), - num_return_sequences=args.num_return_sequences, - penalty_alpha=0.6, - top_k=4, + **inputs, + max_length=args.length + len(inputs["input_ids"][0]), + penalty_alpha=args.penalty_alpha, + top_k=args.k, ) - # Remove the batch dimension when returning multiple sequences - if len(output_sequences.shape) > 2: - output_sequences.squeeze_() - generated_sequences = [] - for generated_sequence_idx, generated_sequence in enumerate(output_sequences): print(f"=== GENERATED SEQUENCE {generated_sequence_idx + 1} ===") generated_sequence = generated_sequence.tolist() # Decode text - text = tokenizer.decode(generated_sequence, clean_up_tokenization_spaces=True) + text = tokenizer.decode(generated_sequence, clean_up_tokenization_spaces=True, add_special_tokens=False) # Remove all text after the stop token text = text[: text.find(args.stop_token) if args.stop_token else None] # Add the prompt at the beginning of the sequence. Remove the excess text that was used for pre-processing total_sequence = ( - prompt_text + text[len(tokenizer.decode(encoded_prompt[0], clean_up_tokenization_spaces=True)) :] + prompt_text + text[len(tokenizer.decode(inputs["input_ids"][0], clean_up_tokenization_spaces=True)) :] ) generated_sequences.append(total_sequence) From ce26f9f4eee4ddcf85d59ddae4ec76278c691697 Mon Sep 17 00:00:00 2001 From: gmftbyGMFTBY <18811371908@163.com> Date: Thu, 13 Oct 2022 00:07:24 +0800 Subject: [PATCH 06/17] revise the necessary documents --- docs/source/en/internal/generation_utils.mdx | 3 ++ .../en/main_classes/text_generation.mdx | 1 + src/transformers/__init__.py | 2 ++ .../generation_contrastive_search.py | 19 ++++++++++++ src/transformers/generation_utils.py | 31 ++++++++++++------- src/transformers/utils/dummy_pt_objects.py | 7 +++++ 6 files changed, 51 insertions(+), 12 deletions(-) diff --git a/docs/source/en/internal/generation_utils.mdx b/docs/source/en/internal/generation_utils.mdx index bdb6c7c59ce32..8efb7ae2fd7d6 100644 --- a/docs/source/en/internal/generation_utils.mdx +++ b/docs/source/en/internal/generation_utils.mdx @@ -14,6 +14,7 @@ specific language governing permissions and limitations under the License. This page lists all the utility functions used by [`~generation_utils.GenerationMixin.generate`], [`~generation_utils.GenerationMixin.greedy_search`], +[`~generation_utils.GenerationMixin.contrastive_search`], [`~generation_utils.GenerationMixin.sample`], [`~generation_utils.GenerationMixin.beam_search`], [`~generation_utils.GenerationMixin.beam_sample`], @@ -261,3 +262,5 @@ A [`Constraint`] can be used to force the generation to include specific tokens [[autodoc]] top_k_top_p_filtering [[autodoc]] tf_top_k_top_p_filtering + +[[autodoc]] ContrastiveDecodingOneStepFast diff --git a/docs/source/en/main_classes/text_generation.mdx b/docs/source/en/main_classes/text_generation.mdx index 94deeeae89411..db23189ab99c6 100644 --- a/docs/source/en/main_classes/text_generation.mdx +++ b/docs/source/en/main_classes/text_generation.mdx @@ -26,6 +26,7 @@ Each framework has a generate method for auto-regressive text generation impleme - sample - beam_search - beam_sample + - contrastive_search - group_beam_search - constrained_beam_search diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 0e69839f0ec17..4d5c6bc596899 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -800,6 +800,7 @@ "PhrasalConstraint", ] _import_structure["generation_beam_search"] = ["BeamScorer", "BeamSearchScorer", "ConstrainedBeamSearchScorer"] + _import_structure["generation_contrastive_search"] = ["ContrastiveDecodingOneStepFast"] _import_structure["generation_logits_process"] = [ "ForcedBOSTokenLogitsProcessor", "ForcedEOSTokenLogitsProcessor", @@ -3746,6 +3747,7 @@ PhrasalConstraint, ) from .generation_beam_search import BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer + from .generation_contrastive_search import ContrastiveDecodingOneStepFast from .generation_logits_process import ( ForcedBOSTokenLogitsProcessor, ForcedEOSTokenLogitsProcessor, diff --git a/src/transformers/generation_contrastive_search.py b/src/transformers/generation_contrastive_search.py index 286249a6eb85d..1dc8bc1a5cf52 100644 --- a/src/transformers/generation_contrastive_search.py +++ b/src/transformers/generation_contrastive_search.py @@ -65,6 +65,25 @@ def ContrastiveDecodingOneStepFast( """ contrastive search first selects top-k candidates by the logit scores; then these candidate tokens are fed into the language models to compute the degeneration penalty, which will be used to re-rank these candidate tokens. + + Args: + model: the generation model + beam_width (`int`, *optional*, defauls to 1): + If > 1, top k candidate tokens are selected for re-ranking + penalty_alpha (`float`, *optional*, defaults to 0.0): + If > 0, the model confidence will minus the degeneration penalty that is weighted by the penalty_alpha + parameter + past_key_value (`Tuple[Tuple[torch.FloatTensor]]`, *optional*, defaults to None): + it saves the cached key-value results that computed by previous steps + last_hidden_states (`torch.FloatTensor`, *optional*, defaults to None): + the last_hidden_states generated by the previous step + logit_for_next_step (`torch.FloatTensor`, *optional*, defaults to None): + the logit_for_next_step generated by the previous step + is_encoder_decode (`bool`, *optional*, defaults to False): + if True, the model is an encoder-decode model, else the model is an decoder-only model, such as GPT2 and + OPT + + From: https://github.com/yxuansu/SimCTG/blob/main/simctg/utlisgpt.py """ bsz, seqlen, embed_dim = last_hidden_states.size() next_probs = nn.functional.softmax(logit_for_next_step, dim=-1) diff --git a/src/transformers/generation_utils.py b/src/transformers/generation_utils.py index c0132743c29b8..3cc2e410f6b0a 100644 --- a/src/transformers/generation_utils.py +++ b/src/transformers/generation_utils.py @@ -443,6 +443,8 @@ class GenerationMixin: The class exposes [`~generation_utils.GenerationMixin.generate`], which can be used for: - *greedy decoding* by calling [`~generation_utils.GenerationMixin.greedy_search`] if `num_beams=1` and `do_sample=False`. + - *contrastive search* by calling [`~generation_utils.GenerationMixin.contrastive_search`] if `penalty_alpha>0` + and `top_k>1` - *multinomial sampling* by calling [`~generation_utils.GenerationMixin.sample`] if `num_beams=1` and `do_sample=True`. - *beam-search decoding* by calling [`~generation_utils.GenerationMixin.beam_search`] if `num_beams>1` and @@ -1767,7 +1769,9 @@ def contrastive_search( ) -> Union[GreedySearchOutput, torch.LongTensor]: r""" Generates sequences of token ids for models with a language modeling head using **contrastive search** and can - be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models. + be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models. The utils functions + [`ContrastiveDecodingOneStepFast`] is used in this function, which supports the batch generation and the cached + mechanism for speeding up decoding Parameters: input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): @@ -1806,10 +1810,11 @@ def contrastive_search( If model is an encoder-decoder model the kwargs should include `encoder_outputs`. Return: - [`~generation_utils.GreedySearchDecoderOnlyOutput`], [`~generation_utils.GreedySearchEncoderDecoderOutput`] - or `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a - [`~generation_utils.GreedySearchDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and - `return_dict_in_generate=True` or a [`~generation_utils.GreedySearchEncoderDecoderOutput`] if + [`~generation_utils.ContrastiveSearchDecoderOnlyOutput`], + [`~generation_utils.ContrastiveSearchEncoderDecoderOutput`] or `torch.LongTensor`: A `torch.LongTensor` + containing the generated tokens (default behaviour) or a + [`~generation_utils.ContrastiveSearchDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and + `return_dict_in_generate=True` or a [`~generation_utils.ContrastiveSearchEncoderDecoderOutput`] if `model.config.is_encoder_decoder=True`. Examples: @@ -1823,16 +1828,18 @@ def contrastive_search( ... MaxLengthCriteria, ... ) - >>> tokenizer = AutoTokenizer.from_pretrained("gpt2") - >>> model = AutoModelForCausalLM.from_pretrained("gpt2") + >>> tokenizer = AutoTokenizer.from_pretrained("gpt2-large") + >>> model = AutoModelForCausalLM.from_pretrained("gpt2-large") >>> # set pad_token_id to eos_token_id because GPT2 does not have a PAD token >>> model.config.pad_token_id = model.config.eos_token_id - >>> input_prompt = "It might be possible to" - >>> input_ids = tokenizer(input_prompt, return_tensors="pt").input_ids - >>> stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=20)]) - >>> outputs = model.contrastive_search(input_ids, stopping_criteria=stopping_criteria) + >>> input_prompt = "DeepMind Company is" + >>> input_ids = tokenizer(input_prompt, return_tensors="pt") + >>> stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=32)]) + >>> outputs = model.contrastive_search( + ... **input_ids, penalty_alpha=0.6, top_k=4, stopping_criteria=stopping_criteria + ... ) >>> tokenizer.batch_decode(outputs, skip_special_tokens=True) - ["It might be possible to get a better understanding of the nature of the problem, but it's not"] + ["DeepMind Company is a leader in artificial intelligence (AI). We have a long history of working with companies such as Google, Facebook, Amazon, and Microsoft to"] ```""" # init values logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index 5c9cf9cb43f64..5ad85fcd54df2 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -129,6 +129,13 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) +class ContrastiveDecodingOneStepFast(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + class ForcedBOSTokenLogitsProcessor(metaclass=DummyObject): _backends = ["torch"] From 68429ad90a70b70d2b4fc3540f86172e12e787d1 Mon Sep 17 00:00:00 2001 From: gmftbyGMFTBY <18811371908@163.com> Date: Thu, 13 Oct 2022 15:41:56 +0800 Subject: [PATCH 07/17] fix: revise the docstring of generation_contrastive_search.py --- .../generation_contrastive_search.py | 40 ++++++++++--------- src/transformers/generation_utils.py | 6 +-- 2 files changed, 24 insertions(+), 22 deletions(-) diff --git a/src/transformers/generation_contrastive_search.py b/src/transformers/generation_contrastive_search.py index 1dc8bc1a5cf52..de424fbf64ff4 100644 --- a/src/transformers/generation_contrastive_search.py +++ b/src/transformers/generation_contrastive_search.py @@ -67,23 +67,22 @@ def ContrastiveDecodingOneStepFast( language models to compute the degeneration penalty, which will be used to re-rank these candidate tokens. Args: - model: the generation model + model ([`PretrainedModel`]): The decoder-only or encoder-decoder generation language models. beam_width (`int`, *optional*, defauls to 1): - If > 1, top k candidate tokens are selected for re-ranking + If > 1, top k candidate tokens are selected for re-ranking. penalty_alpha (`float`, *optional*, defaults to 0.0): If > 0, the model confidence will minus the degeneration penalty that is weighted by the penalty_alpha - parameter - past_key_value (`Tuple[Tuple[torch.FloatTensor]]`, *optional*, defaults to None): - it saves the cached key-value results that computed by previous steps - last_hidden_states (`torch.FloatTensor`, *optional*, defaults to None): - the last_hidden_states generated by the previous step - logit_for_next_step (`torch.FloatTensor`, *optional*, defaults to None): - the logit_for_next_step generated by the previous step - is_encoder_decode (`bool`, *optional*, defaults to False): - if True, the model is an encoder-decode model, else the model is an decoder-only model, such as GPT2 and - OPT - - From: https://github.com/yxuansu/SimCTG/blob/main/simctg/utlisgpt.py + parameter. + past_key_value (`Tuple[Tuple[torch.FloatTensor]]`): + It saves the cached key-value results that computed by previous steps. + last_hidden_states (`torch.FloatTensor`, *optional*): + The last_hidden_states generated by the previous step. + logit_for_next_step (`torch.FloatTensor`, *optional*): + The logit_for_next_step generated by the previous step. + is_encoder_decoder (`bool`, *optional*, defaults to `False`): + If `True`, the model is an encoder-decode model, otherwise the model is a decoder-only model. + + From: https://github.com/yxuansu/SimCTG/blob/main/contrastive_search_explanation/README.md """ bsz, seqlen, embed_dim = last_hidden_states.size() next_probs = nn.functional.softmax(logit_for_next_step, dim=-1) @@ -92,11 +91,14 @@ def ContrastiveDecodingOneStepFast( past_key_values = enlarge_past_key_values(past_key_values, beam_width) # build next attention mask - attention_mask = model_inputs["attention_mask"] # [B, S] - # decoder-only model need the full attention mask, not only the mask for the last token - if is_encoder_decoder is False: - attention_mask = torch.cat([attention_mask, attention_mask.new_ones((bsz, 1))], dim=-1) - attention_mask = attention_mask.unsqueeze(1).expand(-1, beam_width, -1).reshape(-1, attention_mask.size(-1)) + if 'attention_mask' in model_inputs: + attention_mask = model_inputs["attention_mask"] # [B, S] + # decoder-only model need the full attention mask, not only the mask for the last token + if is_encoder_decoder is False: + attention_mask = torch.cat([attention_mask, attention_mask.new_ones((bsz, 1))], dim=-1) + attention_mask = attention_mask.unsqueeze(1).expand(-1, beam_width, -1).reshape(-1, attention_mask.size(-1)) + else: + attention_mask = None # encoder-decoder model also contains the `encoder_outputs` if is_encoder_decoder and "encoder_outputs" in model_inputs: diff --git a/src/transformers/generation_utils.py b/src/transformers/generation_utils.py index 3cc2e410f6b0a..8ddbf461fea51 100644 --- a/src/transformers/generation_utils.py +++ b/src/transformers/generation_utils.py @@ -1828,8 +1828,8 @@ def contrastive_search( ... MaxLengthCriteria, ... ) - >>> tokenizer = AutoTokenizer.from_pretrained("gpt2-large") - >>> model = AutoModelForCausalLM.from_pretrained("gpt2-large") + >>> tokenizer = AutoTokenizer.from_pretrained("gpt2") + >>> model = AutoModelForCausalLM.from_pretrained("gpt2") >>> # set pad_token_id to eos_token_id because GPT2 does not have a PAD token >>> model.config.pad_token_id = model.config.eos_token_id >>> input_prompt = "DeepMind Company is" @@ -1839,7 +1839,7 @@ def contrastive_search( ... **input_ids, penalty_alpha=0.6, top_k=4, stopping_criteria=stopping_criteria ... ) >>> tokenizer.batch_decode(outputs, skip_special_tokens=True) - ["DeepMind Company is a leader in artificial intelligence (AI). We have a long history of working with companies such as Google, Facebook, Amazon, and Microsoft to"] + ["DeepMind Company is a non-profit organization dedicated to advancing the interests of the people of the United States of America. We are committed to providing a safe,"] ```""" # init values logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() From e1f0db96e3124cf03098cd294901654ec818be5b Mon Sep 17 00:00:00 2001 From: GMFTBY <18811371908@163.com> Date: Thu, 13 Oct 2022 15:45:40 +0800 Subject: [PATCH 08/17] Fix the code indentation --- docs/source/en/main_classes/text_generation.mdx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/en/main_classes/text_generation.mdx b/docs/source/en/main_classes/text_generation.mdx index db23189ab99c6..5c81dc9880186 100644 --- a/docs/source/en/main_classes/text_generation.mdx +++ b/docs/source/en/main_classes/text_generation.mdx @@ -26,7 +26,7 @@ Each framework has a generate method for auto-regressive text generation impleme - sample - beam_search - beam_sample - - contrastive_search + - contrastive_search - group_beam_search - constrained_beam_search From d2a5e027ec3a69b613cd0386db96db6eae320f5c Mon Sep 17 00:00:00 2001 From: gmftbyGMFTBY <18811371908@163.com> Date: Thu, 13 Oct 2022 16:56:14 +0800 Subject: [PATCH 09/17] fix: revise the nits and examples in contrastive_search docstring. --- docs/source/en/main_classes/text_generation.mdx | 2 +- src/transformers/generation_contrastive_search.py | 2 +- src/transformers/generation_utils.py | 8 ++++---- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/docs/source/en/main_classes/text_generation.mdx b/docs/source/en/main_classes/text_generation.mdx index 5c81dc9880186..2fc7950cdbe1c 100644 --- a/docs/source/en/main_classes/text_generation.mdx +++ b/docs/source/en/main_classes/text_generation.mdx @@ -26,7 +26,7 @@ Each framework has a generate method for auto-regressive text generation impleme - sample - beam_search - beam_sample - - contrastive_search + - contrastive_search - group_beam_search - constrained_beam_search diff --git a/src/transformers/generation_contrastive_search.py b/src/transformers/generation_contrastive_search.py index de424fbf64ff4..f1dc07877c948 100644 --- a/src/transformers/generation_contrastive_search.py +++ b/src/transformers/generation_contrastive_search.py @@ -91,7 +91,7 @@ def ContrastiveDecodingOneStepFast( past_key_values = enlarge_past_key_values(past_key_values, beam_width) # build next attention mask - if 'attention_mask' in model_inputs: + if "attention_mask" in model_inputs: attention_mask = model_inputs["attention_mask"] # [B, S] # decoder-only model need the full attention mask, not only the mask for the last token if is_encoder_decoder is False: diff --git a/src/transformers/generation_utils.py b/src/transformers/generation_utils.py index 8ddbf461fea51..b540f0726cbf4 100644 --- a/src/transformers/generation_utils.py +++ b/src/transformers/generation_utils.py @@ -1828,18 +1828,18 @@ def contrastive_search( ... MaxLengthCriteria, ... ) - >>> tokenizer = AutoTokenizer.from_pretrained("gpt2") - >>> model = AutoModelForCausalLM.from_pretrained("gpt2") + >>> tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m") + >>> model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m") >>> # set pad_token_id to eos_token_id because GPT2 does not have a PAD token >>> model.config.pad_token_id = model.config.eos_token_id >>> input_prompt = "DeepMind Company is" >>> input_ids = tokenizer(input_prompt, return_tensors="pt") >>> stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=32)]) >>> outputs = model.contrastive_search( - ... **input_ids, penalty_alpha=0.6, top_k=4, stopping_criteria=stopping_criteria + ... **input_ids, penalty_alpha=0.6, top_k=5, stopping_criteria=stopping_criteria ... ) >>> tokenizer.batch_decode(outputs, skip_special_tokens=True) - ["DeepMind Company is a non-profit organization dedicated to advancing the interests of the people of the United States of America. We are committed to providing a safe,"] + ["DeepMind Company is a leader in Artificial Intelligence (AI) and has been recognized by industry leaders as one of the fastest growing companies in the AI industry."] ```""" # init values logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() From d5d30b78185393982048d99534382d8bd0e9edf8 Mon Sep 17 00:00:00 2001 From: gmftbyGMFTBY <18811371908@163.com> Date: Fri, 14 Oct 2022 21:10:58 +0800 Subject: [PATCH 10/17] fix the copyright --- .../run_generation_contrastive_search.py | 7 +++---- src/transformers/generation_contrastive_search.py | 11 ++++------- 2 files changed, 7 insertions(+), 11 deletions(-) diff --git a/examples/pytorch/text-generation/run_generation_contrastive_search.py b/examples/pytorch/text-generation/run_generation_contrastive_search.py index 2b592f0219a07..117f063a6dd9a 100755 --- a/examples/pytorch/text-generation/run_generation_contrastive_search.py +++ b/examples/pytorch/text-generation/run_generation_contrastive_search.py @@ -1,7 +1,6 @@ #!/usr/bin/env python # coding=utf-8 -# Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HuggingFace Inc. team. -# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# Copyright 2022 University of Cambridge, Tencent AI Lab, DeepMind and The University of Hong Kong Authors and The HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -16,8 +15,8 @@ # limitations under the License. """ The examples of running contrastive search on the auto-APIs; -Running this examples: -CUDA_VISIBLE_DEVICES=0 python run_generation_contrastive_search.py --model_name_or_path=gpt2-large --penalty_alpha=0.6 --k=4 --length=256 +Running this example: +python run_generation_contrastive_search.py --model_name_or_path=gpt2-large --penalty_alpha=0.6 --k=4 --length=256 """ diff --git a/src/transformers/generation_contrastive_search.py b/src/transformers/generation_contrastive_search.py index f1dc07877c948..9648a91cd088d 100644 --- a/src/transformers/generation_contrastive_search.py +++ b/src/transformers/generation_contrastive_search.py @@ -1,6 +1,5 @@ # coding=utf-8 -# Copyright 2020 The Google AI Language Team Authors, Facebook AI Research authors and The HuggingFace Inc. team. -# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# Copyright 2022 University of Cambridge, Tencent AI Lab, DeepMind and The University of Hong Kong Authors and The HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,6 +12,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +""" +This file contains the utils functions for the contrastive search, which will be called in `generation_utils` +""" from typing import Tuple @@ -25,11 +27,6 @@ logger = logging.get_logger(__name__) -""" -This file contains the utils functions for the contrastive search, which will be called in `generation_utils` -""" - - def ranking_fast( context_hidden: torch.FloatTensor, next_hidden: torch.FloatTensor, From c344a0a51af5abc43ab67623b44a855208d14d12 Mon Sep 17 00:00:00 2001 From: gmftbyGMFTBY <18811371908@163.com> Date: Mon, 17 Oct 2022 23:36:24 +0800 Subject: [PATCH 11/17] delete generation_contrastive_search.py --- docs/source/en/internal/generation_utils.mdx | 2 - src/transformers/__init__.py | 2 - .../generation_contrastive_search.py | 196 ----------------- src/transformers/generation_utils.py | 199 ++++++++++++++---- src/transformers/utils/dummy_pt_objects.py | 7 - 5 files changed, 162 insertions(+), 244 deletions(-) delete mode 100644 src/transformers/generation_contrastive_search.py diff --git a/docs/source/en/internal/generation_utils.mdx b/docs/source/en/internal/generation_utils.mdx index 8efb7ae2fd7d6..31db57740eca5 100644 --- a/docs/source/en/internal/generation_utils.mdx +++ b/docs/source/en/internal/generation_utils.mdx @@ -262,5 +262,3 @@ A [`Constraint`] can be used to force the generation to include specific tokens [[autodoc]] top_k_top_p_filtering [[autodoc]] tf_top_k_top_p_filtering - -[[autodoc]] ContrastiveDecodingOneStepFast diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index df1d803ae9243..263a9a27cc22c 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -802,7 +802,6 @@ "PhrasalConstraint", ] _import_structure["generation_beam_search"] = ["BeamScorer", "BeamSearchScorer", "ConstrainedBeamSearchScorer"] - _import_structure["generation_contrastive_search"] = ["ContrastiveDecodingOneStepFast"] _import_structure["generation_logits_process"] = [ "ForcedBOSTokenLogitsProcessor", "ForcedEOSTokenLogitsProcessor", @@ -3751,7 +3750,6 @@ PhrasalConstraint, ) from .generation_beam_search import BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer - from .generation_contrastive_search import ContrastiveDecodingOneStepFast from .generation_logits_process import ( ForcedBOSTokenLogitsProcessor, ForcedEOSTokenLogitsProcessor, diff --git a/src/transformers/generation_contrastive_search.py b/src/transformers/generation_contrastive_search.py deleted file mode 100644 index 9648a91cd088d..0000000000000 --- a/src/transformers/generation_contrastive_search.py +++ /dev/null @@ -1,196 +0,0 @@ -# coding=utf-8 -# Copyright 2022 University of Cambridge, Tencent AI Lab, DeepMind and The University of Hong Kong Authors and The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -This file contains the utils functions for the contrastive search, which will be called in `generation_utils` -""" - -from typing import Tuple - -import torch -from torch import nn - -from .utils import logging - - -logger = logging.get_logger(__name__) - - -def ranking_fast( - context_hidden: torch.FloatTensor, - next_hidden: torch.FloatTensor, - next_top_k_probs: torch.FloatTensor, - alpha: float, - beam_width: int, -) -> Tuple[torch.FloatTensor]: - """ - context_hidden: bsz*beam x seqlen x embed_dim next_hidden: bsz*beam x 1 x embed_dim next_top_k_probs: bsz x beam - """ - _, context_len, embed_dim = context_hidden.size() - norm_context_hidden = context_hidden / context_hidden.norm(dim=2, keepdim=True) - norm_next_hidden = next_hidden / next_hidden.norm(dim=2, keepdim=True) - cosine_matrix = torch.matmul(norm_context_hidden, norm_next_hidden.transpose(1, 2)).squeeze(-1) # [B*K, S] - scores, _ = torch.max(cosine_matrix, dim=-1) # [B*K] - next_top_k_probs = next_top_k_probs.view(-1) # [B*K] - scores = (1.0 - alpha) * next_top_k_probs - alpha * scores - scores = torch.stack(torch.split(scores, beam_width)) # [B, K] - selected_scores, selected_idx = scores.max(dim=-1) # [B] - return selected_scores, selected_idx - - -def ContrastiveDecodingOneStepFast( - model, - beam_width: int = 1, - penalty_alpha: float = 0.0, - past_key_values: Tuple[Tuple[torch.FloatTensor]] = None, - last_hidden_states: torch.FloatTensor = None, - logit_for_next_step: torch.FloatTensor = None, - is_encoder_decoder: bool = False, - **model_inputs, -) -> Tuple: - """ - contrastive search first selects top-k candidates by the logit scores; then these candidate tokens are fed into the - language models to compute the degeneration penalty, which will be used to re-rank these candidate tokens. - - Args: - model ([`PretrainedModel`]): The decoder-only or encoder-decoder generation language models. - beam_width (`int`, *optional*, defauls to 1): - If > 1, top k candidate tokens are selected for re-ranking. - penalty_alpha (`float`, *optional*, defaults to 0.0): - If > 0, the model confidence will minus the degeneration penalty that is weighted by the penalty_alpha - parameter. - past_key_value (`Tuple[Tuple[torch.FloatTensor]]`): - It saves the cached key-value results that computed by previous steps. - last_hidden_states (`torch.FloatTensor`, *optional*): - The last_hidden_states generated by the previous step. - logit_for_next_step (`torch.FloatTensor`, *optional*): - The logit_for_next_step generated by the previous step. - is_encoder_decoder (`bool`, *optional*, defaults to `False`): - If `True`, the model is an encoder-decode model, otherwise the model is a decoder-only model. - - From: https://github.com/yxuansu/SimCTG/blob/main/contrastive_search_explanation/README.md - """ - bsz, seqlen, embed_dim = last_hidden_states.size() - next_probs = nn.functional.softmax(logit_for_next_step, dim=-1) - _, top_k_ids = torch.topk(logit_for_next_step, dim=-1, k=beam_width) - top_k_probs = torch.gather(next_probs, dim=1, index=top_k_ids) - past_key_values = enlarge_past_key_values(past_key_values, beam_width) - - # build next attention mask - if "attention_mask" in model_inputs: - attention_mask = model_inputs["attention_mask"] # [B, S] - # decoder-only model need the full attention mask, not only the mask for the last token - if is_encoder_decoder is False: - attention_mask = torch.cat([attention_mask, attention_mask.new_ones((bsz, 1))], dim=-1) - attention_mask = attention_mask.unsqueeze(1).expand(-1, beam_width, -1).reshape(-1, attention_mask.size(-1)) - else: - attention_mask = None - - # encoder-decoder model also contains the `encoder_outputs` - if is_encoder_decoder and "encoder_outputs" in model_inputs: - encoder_outputs = model_inputs["encoder_outputs"] - else: - encoder_outputs = None - next_model_inputs = model.prepare_inputs_for_generation( - top_k_ids.view(-1, 1), - past=past_key_values, - attention_mask=attention_mask, - use_cache=True, - encoder_outputs=encoder_outputs, - ) - # compute the candidate tokens by the language model and collects their hidden_states - output = model(output_hidden_states=True, **next_model_inputs) - past_key_values = output.past_key_values - logits = output.logits[:, -1, :] - # name is different for encoder-decoder and decoder-only models - if is_encoder_decoder: - next_hidden = output.decoder_hidden_states[-1] - full_hidden_states = output.decoder_hidden_states - else: - next_hidden = output.hidden_states[-1] - full_hidden_states = output.hidden_states - context_hidden = ( - last_hidden_states.unsqueeze(1).expand(-1, beam_width, -1, -1).reshape(bsz * beam_width, seqlen, embed_dim) - ) - - # compute the degeneratin penalty and re-rank the candidates based on the degeneration penalty and the model confidence - # the scores and index of the selected tokens are returned - selected_scores, selected_idx = ranking_fast( - context_hidden, - next_hidden, - top_k_probs, - penalty_alpha, - beam_width, - ) - # prepare for the next step: (1) next token_id; (2) past_key_values; (3) last_hidden_states for computing the degeneration penalty; (4) logits for selecting next top-k candidates; (5) selected tokens scores (model confidence minus degeneration penalty); (6) decoder hidden_states - next_id = top_k_ids[range(len(top_k_ids)), selected_idx].unsqueeze(-1) - next_hidden = torch.stack(torch.split(next_hidden.squeeze(dim=1), beam_width)) - next_hidden = next_hidden[range(bsz), selected_idx, :] - last_hidden_states = torch.cat([last_hidden_states, next_hidden.unsqueeze(1)], dim=1) - - decoder_hidden_states = [] - for layer in full_hidden_states: - layer = torch.stack(torch.split(layer.squeeze(dim=1), beam_width)) - layer = layer[range(bsz), selected_idx, :] - decoder_hidden_states.append(layer) - - past_key_values = select_past_key_values(past_key_values, beam_width, selected_idx) - logits = torch.stack(torch.split(logits, beam_width))[range(bsz), selected_idx, :] - return next_id.squeeze(dim=-1), past_key_values, last_hidden_states, logits, selected_scores, decoder_hidden_states - - -def enlarge_past_key_values( - past_key_values: Tuple[Tuple[torch.FloatTensor]], beam_width: int -) -> Tuple[Tuple[torch.FloatTensor]]: - """ - Copy and extend the past_key_values for the next step re-rank each item in `past_key_values` is the 4-dimension - matrix, whose shapre is [batch_size, num_head, seq_len, embed_dim] Suppose the size of the next token candidate - size is K, we need to obtain the new `past_key_values`, whose shape is [batch_size*K, num_head, seq_len, embed_dim] - """ - # from [B, num_head, seq_len, esz] to [B*K, num_head, seq_len, esz] - new_key_values = [] - for layer in past_key_values: - items = [] - # item is either the key or the value matrix - for item in layer: - bsz, num_head, seq_len, esz = item.size() - item = ( - item.unsqueeze(1).expand(-1, beam_width, -1, -1, -1).reshape(bsz * beam_width, num_head, seq_len, esz) - ) # [bsz*beam, num_head, seq_len, esz] - items.append(item) - new_key_values.append(items) - return new_key_values - - -def select_past_key_values( - past_key_values: Tuple[Tuple[torch.FloatTensor]], beam_width: int, selected_idx: torch.FloatTensor -) -> Tuple[Tuple[torch.FloatTensor]]: - """ - Extract the `past_key_value` for the selected tokens, each item in `past_key_value` is the 4-dimension matrix, - whose shape is [batch_size*K, num_head, seq_len, embed_dim], where K is the number of the candidate tokens. We aim - to obtain the `past_key_value` of the selected next token, whose shape is [batch_size, num_head, seq_len, - embed_dim] - """ - new_key_values = [] - for layer in past_key_values: - items = [] - # item is either the key or the value matrix - for item in layer: - bsz_and_beam, num_head, seq_len, esz = item.size() - bsz = int(bsz_and_beam // beam_width) - item = torch.stack(torch.split(item, beam_width, dim=0)) # [B, K, num_head, seq_len, esz] - item = item[range(bsz), selected_idx, :, :, :] # [B, num_head, seq_len, esz] - items.append(item) - new_key_values.append(items) - return new_key_values diff --git a/src/transformers/generation_utils.py b/src/transformers/generation_utils.py index b540f0726cbf4..14e088dc2d0c3 100644 --- a/src/transformers/generation_utils.py +++ b/src/transformers/generation_utils.py @@ -25,7 +25,6 @@ from .generation_beam_constraints import Constraint, DisjunctiveConstraint, PhrasalConstraint from .generation_beam_search import BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer -from .generation_contrastive_search import ContrastiveDecodingOneStepFast from .generation_logits_process import ( EncoderNoRepeatNGramLogitsProcessor, ExponentialDecayLengthPenalty, @@ -1497,14 +1496,25 @@ def generate( if num_return_sequences > 1: raise ValueError( - f"num_return_sequences has to be 1, but is {num_return_sequences} when doing greedy search." + f"num_return_sequences has to be 1, but is {num_return_sequences} when doing contrastive search." ) + # 10. prepare logits warper: get the TopKLogitsWarper for contrastive_search + logits_warper = self._get_logits_warper( + top_k=top_k, + top_p=top_p, + typical_p=typical_p, + temperature=temperature, + num_beams=num_beams, + renormalize_logits=renormalize_logits, + ) + return self.contrastive_search( input_ids, top_k=top_k, penalty_alpha=penalty_alpha, logits_processor=logits_processor, + logits_warper=logits_warper, stopping_criteria=stopping_criteria, pad_token_id=pad_token_id, eos_token_id=eos_token_id, @@ -1756,6 +1766,7 @@ def contrastive_search( top_k: Optional[int] = 1, penalty_alpha: Optional[float] = 0, logits_processor: Optional[LogitsProcessorList] = None, + logits_warper: Optional[LogitsProcessorList] = None, stopping_criteria: Optional[StoppingCriteriaList] = None, max_length: Optional[int] = None, pad_token_id: Optional[int] = None, @@ -1826,6 +1837,7 @@ def contrastive_search( ... MinLengthLogitsProcessor, ... StoppingCriteriaList, ... MaxLengthCriteria, + ... ContrastiveSearchRankingLogitsProcessor, ... ) >>> tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m") @@ -1835,8 +1847,17 @@ def contrastive_search( >>> input_prompt = "DeepMind Company is" >>> input_ids = tokenizer(input_prompt, return_tensors="pt") >>> stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=32)]) + >>> logits_processor = LogitsProcessorList( + ... [ + ... ContrastiveSearchRankingLogitsProcessor(penalty_alpha=0.6, beam_width=5), + ... ] + ... ) >>> outputs = model.contrastive_search( - ... **input_ids, penalty_alpha=0.6, top_k=5, stopping_criteria=stopping_criteria + ... **input_ids, + ... penalty_alpha=0.6, + ... top_k=5, + ... stopping_criteria=stopping_criteria, + ... logits_processor=logits_processor, ... ) >>> tokenizer.batch_decode(outputs, skip_special_tokens=True) ["DeepMind Company is a leader in Artificial Intelligence (AI) and has been recognized by industry leaders as one of the fastest growing companies in the AI industry."] @@ -1871,24 +1892,7 @@ def contrastive_search( this_peer_finished = False # used by synced_gpus only - # encode the given prefix and prepare model inputs; encoder-decoder model process the prefix and save the `encoder_outputs` - model_kwargs["use_cache"] = True - model_kwargs["past_key_values"] = None - model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) - output = self(**model_inputs, output_hidden_states=True, output_attentions=True) - - # past_key_values is activated for fast decoding - past_key_values = output.past_key_values - model_inputs["past_key_values"] = past_key_values - - # last decoder hidden states will be used to compute the degeneration penalty (cosine similarity with previous tokens) - if self.config.is_encoder_decoder: - last_hidden_states = output.decoder_hidden_states[-1] - else: - last_hidden_states = output.hidden_states[-1] - # next logit for contrastive search to select top-k candidate tokens - logit_for_next_step = output.logits[:, -1, :] - + step_counter = 0 while True: if synced_gpus: # Under synced_gpus the `forward` call must continue until all gpus complete their sequence. @@ -1900,23 +1904,121 @@ def contrastive_search( if this_peer_finished_flag.item() == 0.0: break + # if the first step in the loop, encode all the prefix and obtain three parameters: (1) past_key_values; (2) last_hidden_states; (3) logit_for_next_step + if step_counter == 0: + # encode the given prefix and prepare model inputs; encoder-decoder model process the prefix and save the `encoder_outputs` + model_kwargs["use_cache"] = True + model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) + output = self(**model_inputs, output_hidden_states=True, output_attentions=True) + # past_key_values is activated for fast decoding + past_key_values = output.past_key_values + # last decoder hidden states will be used to compute the degeneration penalty (cosine similarity with previous tokens) + if self.config.is_encoder_decoder: + last_hidden_states = output.decoder_hidden_states[-1] + else: + last_hidden_states = output.hidden_states[-1] + # next logit for contrastive search to select top-k candidate tokens + logit_for_next_step = output.logits[:, -1, :] + + # contrastive_search main logic start: # contrastive search decoding consists of two steps: (1) candidate tokens recall; (2) candidate re-rank by degeneration penalty - ( - next_tokens, - past_key_values, - last_hidden_states, - logit_for_next_step, - selected_scores, - decoder_hidden_states_one_step, - ) = ContrastiveDecodingOneStepFast( - self, - beam_width=top_k, - penalty_alpha=penalty_alpha, - last_hidden_states=last_hidden_states, - logit_for_next_step=logit_for_next_step, - is_encoder_decoder=self.config.is_encoder_decoder, - **model_inputs, - ) + bsz, seqlen, embed_dim = last_hidden_states.size() + + # logits processor: empty logits processor + logit_for_next_step = logits_processor(input_ids, logit_for_next_step) + next_probs = nn.functional.softmax(logit_for_next_step, dim=-1) + # logit warper: applying after the softmax, which is consistent with the logic in the paper + logit_for_next_step = logits_warper(input_ids, next_probs) + + _, top_k_ids = torch.topk(logit_for_next_step, dim=-1, k=top_k) + top_k_probs = torch.gather(next_probs, dim=1, index=top_k_ids) + + # enlarge the past_key_values + new_key_values = [] + for layer in past_key_values: + items = [] + # item is either the key or the value matrix + for item in layer: + bsz, num_head, seq_len, esz = item.size() + item = ( + item.unsqueeze(1).expand(-1, top_k, -1, -1, -1).reshape(bsz * top_k, num_head, seq_len, esz) + ) # [bsz*beam, num_head, seq_len, esz] + items.append(item) + new_key_values.append(items) + past_key_values = new_key_values + + # build next attention mask + if "attention_mask" in model_inputs: + attention_mask = model_inputs["attention_mask"] # [B, S] + # decoder-only model need the full attention mask, not only the mask for the last token + if self.config.is_encoder_decoder is False: + attention_mask = torch.cat([attention_mask, attention_mask.new_ones((bsz, 1))], dim=-1) + attention_mask = attention_mask.unsqueeze(1).expand(-1, top_k, -1).reshape(-1, attention_mask.size(-1)) + else: + attention_mask = None + + # encoder-decoder model also contains the `encoder_outputs` + if self.config.is_encoder_decoder and "encoder_outputs" in model_inputs: + encoder_outputs = model_inputs["encoder_outputs"] + else: + encoder_outputs = None + next_model_inputs = self.prepare_inputs_for_generation( + top_k_ids.view(-1, 1), + past=past_key_values, + attention_mask=attention_mask, + use_cache=True, + encoder_outputs=encoder_outputs, + ) + # compute the candidate tokens by the language model and collects their hidden_states + output = self(output_hidden_states=True, **next_model_inputs) + past_key_values = output.past_key_values + logits = output.logits[:, -1, :] + # name is different for encoder-decoder and decoder-only models + if self.config.is_encoder_decoder: + next_hidden = output.decoder_hidden_states[-1] + full_hidden_states = output.decoder_hidden_states + else: + next_hidden = output.hidden_states[-1] + full_hidden_states = output.hidden_states + context_hidden = ( + last_hidden_states.unsqueeze(1).expand(-1, top_k, -1, -1).reshape(bsz * top_k, seqlen, embed_dim) + ) + + # compute the degeneratin penalty and re-rank the candidates based on the degeneration penalty and the model confidence + # the scores and index of the selected tokens are returned + selected_scores, selected_idx = ranking_fast( + context_hidden, next_hidden, top_k_probs, penalty_alpha, top_k + ) + + # prepare for the next step: (1) next token_id; (2) past_key_values; (3) last_hidden_states for computing the degeneration penalty; (4) logits for selecting next top-k candidates; (5) selected tokens scores (model confidence minus degeneration penalty); (6) decoder hidden_states + next_tokens = top_k_ids[range(len(top_k_ids)), selected_idx] + next_hidden = torch.stack(torch.split(next_hidden.squeeze(dim=1), top_k)) + next_hidden = next_hidden[range(bsz), selected_idx, :] + last_hidden_states = torch.cat([last_hidden_states, next_hidden.unsqueeze(1)], dim=1) + + decoder_hidden_states_one_step = [] + for layer in full_hidden_states: + layer = torch.stack(torch.split(layer.squeeze(dim=1), top_k)) + layer = layer[range(bsz), selected_idx, :] + decoder_hidden_states_one_step.append(layer) + + # select the past_key_value + new_key_values = [] + for layer in past_key_values: + items = [] + # item is either the key or the value matrix + for item in layer: + bsz_and_beam, num_head, seq_len, esz = item.size() + bsz = int(bsz_and_beam // top_k) + item = torch.stack(torch.split(item, top_k, dim=0)) # [B, K, num_head, seq_len, esz] + item = item[range(bsz), selected_idx, :, :, :] # [B, num_head, seq_len, esz] + items.append(item) + new_key_values.append(items) + past_key_values = new_key_values + + logit_for_next_step = torch.stack(torch.split(logits, top_k))[range(bsz), selected_idx, :] + # contrastive_search main logic end:: + # after running the above codes, we update following parameters: next_tokens, past_key_values, logit_for_next_step, selected_score, decoder_hidden_states_one_step if synced_gpus and this_peer_finished: continue # don't waste resources running the code we don't need @@ -1963,6 +2065,7 @@ def contrastive_search( # prepare model inputs model_kwargs["past_key_values"] = past_key_values model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) + step_counter += 1 if return_dict_in_generate: if self.config.is_encoder_decoder: @@ -3791,3 +3894,25 @@ def top_k_top_p_filtering( ) return logits + + +def ranking_fast( + context_hidden: torch.FloatTensor, + next_hidden: torch.FloatTensor, + next_top_k_probs: torch.FloatTensor, + alpha: float, + beam_width: int, +) -> Tuple[torch.FloatTensor]: + """ + context_hidden: bsz*beam x seqlen x embed_dim next_hidden: bsz*beam x 1 x embed_dim next_top_k_probs: bsz x beam + """ + _, context_len, embed_dim = context_hidden.size() + norm_context_hidden = context_hidden / context_hidden.norm(dim=2, keepdim=True) + norm_next_hidden = next_hidden / next_hidden.norm(dim=2, keepdim=True) + cosine_matrix = torch.matmul(norm_context_hidden, norm_next_hidden.transpose(1, 2)).squeeze(-1) # [B*K, S] + scores, _ = torch.max(cosine_matrix, dim=-1) # [B*K] + next_top_k_probs = next_top_k_probs.view(-1) # [B*K] + scores = (1.0 - alpha) * next_top_k_probs - alpha * scores + scores = torch.stack(torch.split(scores, beam_width)) # [B, K] + selected_scores, selected_idx = scores.max(dim=-1) # [B] + return selected_scores, selected_idx diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index 5ad85fcd54df2..5c9cf9cb43f64 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -129,13 +129,6 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) -class ContrastiveDecodingOneStepFast(metaclass=DummyObject): - _backends = ["torch"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch"]) - - class ForcedBOSTokenLogitsProcessor(metaclass=DummyObject): _backends = ["torch"] From 7af4cbb5ac86476ff1f5f9dbd38554468165bc06 Mon Sep 17 00:00:00 2001 From: gmftbyGMFTBY <18811371908@163.com> Date: Tue, 18 Oct 2022 10:03:17 +0800 Subject: [PATCH 12/17] revise the logic in contrastive_search --- src/transformers/generation_utils.py | 59 +++++++++------------------- 1 file changed, 18 insertions(+), 41 deletions(-) diff --git a/src/transformers/generation_utils.py b/src/transformers/generation_utils.py index 2216d6ff68634..52a5cc81610bf 100644 --- a/src/transformers/generation_utils.py +++ b/src/transformers/generation_utils.py @@ -1500,22 +1500,11 @@ def generate( f"num_return_sequences has to be 1, but is {num_return_sequences} when doing contrastive search." ) - # 10. prepare logits warper: get the TopKLogitsWarper for contrastive_search - logits_warper = self._get_logits_warper( - top_k=top_k, - top_p=top_p, - typical_p=typical_p, - temperature=temperature, - num_beams=num_beams, - renormalize_logits=renormalize_logits, - ) - return self.contrastive_search( input_ids, top_k=top_k, penalty_alpha=penalty_alpha, logits_processor=logits_processor, - logits_warper=logits_warper, stopping_criteria=stopping_criteria, pad_token_id=pad_token_id, eos_token_id=eos_token_id, @@ -1781,9 +1770,7 @@ def contrastive_search( ) -> Union[GreedySearchOutput, torch.LongTensor]: r""" Generates sequences of token ids for models with a language modeling head using **contrastive search** and can - be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models. The utils functions - [`ContrastiveDecodingOneStepFast`] is used in this function, which supports the batch generation and the cached - mechanism for speeding up decoding + be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models. Parameters: input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): @@ -1834,37 +1821,27 @@ def contrastive_search( >>> from transformers import ( ... AutoTokenizer, ... AutoModelForCausalLM, - ... LogitsProcessorList, ... MinLengthLogitsProcessor, ... StoppingCriteriaList, ... MaxLengthCriteria, - ... ContrastiveSearchRankingLogitsProcessor, ... ) - >>> tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m") - >>> model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m") + >>> tokenizer = AutoTokenizer.from_pretrained("gpt2-large") + >>> model = AutoModelForCausalLM.from_pretrained("gpt2-large") >>> # set pad_token_id to eos_token_id because GPT2 does not have a PAD token >>> model.config.pad_token_id = model.config.eos_token_id >>> input_prompt = "DeepMind Company is" >>> input_ids = tokenizer(input_prompt, return_tensors="pt") - >>> stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=32)]) - >>> logits_processor = LogitsProcessorList( - ... [ - ... ContrastiveSearchRankingLogitsProcessor(penalty_alpha=0.6, beam_width=5), - ... ] - ... ) + >>> stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=256)]) >>> outputs = model.contrastive_search( - ... **input_ids, - ... penalty_alpha=0.6, - ... top_k=5, - ... stopping_criteria=stopping_criteria, - ... logits_processor=logits_processor, + ... **input_ids, penalty_alpha=0.6, top_k=4, stopping_criteria=stopping_criteria ... ) >>> tokenizer.batch_decode(outputs, skip_special_tokens=True) - ["DeepMind Company is a leader in Artificial Intelligence (AI) and has been recognized by industry leaders as one of the fastest growing companies in the AI industry."] + ["DeepMind Company is a leader in artificial intelligence (AI). We have a long history of working with companies such as Google, Facebook, Amazon, and Microsoft to build products that improve people\'s lives, and today we are excited to announce that DeepMind\'s AlphaGo program has won the game of Go, becoming the first program to defeat a professional Go player.\n\nThe victory is a testament to the power of deep learning, and to the incredible work of our research team, which has been at the forefront of AI research for the past five years. AlphaGo is one of the most advanced Go programs ever created, and its performance is an important step towards the goal of human-level AI.\n\n"This is the culmination of a decade of hard work," said Andy Ng, co-founder and CTO of DeepMind. "We are thrilled to have achieved this milestone and look forward to continuing to develop AI that can be used in a wide range of applications and to help people live better lives."\n\nDeepMind\'s work on Go began in 2010, when it began to train a neural network to play Go using millions of games played by top Go players around the world. Since then, the team has refined the algorithm, adding more and more layers of reinforcement"] ```""" # init values logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() + logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList() stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() if max_length is not None: warnings.warn( @@ -1905,11 +1882,13 @@ def contrastive_search( if this_peer_finished_flag.item() == 0.0: break + # prepare inputs + model_kwargs["use_cache"] = True + model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) + # if the first step in the loop, encode all the prefix and obtain three parameters: (1) past_key_values; (2) last_hidden_states; (3) logit_for_next_step if step_counter == 0: # encode the given prefix and prepare model inputs; encoder-decoder model process the prefix and save the `encoder_outputs` - model_kwargs["use_cache"] = True - model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) output = self(**model_inputs, output_hidden_states=True, output_attentions=True) # past_key_values is activated for fast decoding past_key_values = output.past_key_values @@ -1925,11 +1904,10 @@ def contrastive_search( # contrastive search decoding consists of two steps: (1) candidate tokens recall; (2) candidate re-rank by degeneration penalty bsz, seqlen, embed_dim = last_hidden_states.size() - # logits processor: empty logits processor + # logits processor logit_for_next_step = logits_processor(input_ids, logit_for_next_step) + logit_for_next_step = logits_warper(input_ids, logit_for_next_step) next_probs = nn.functional.softmax(logit_for_next_step, dim=-1) - # logit warper: applying after the softmax, which is consistent with the logic in the paper - logit_for_next_step = logits_warper(input_ids, next_probs) _, top_k_ids = torch.topk(logit_for_next_step, dim=-1, k=top_k) top_k_probs = torch.gather(next_probs, dim=1, index=top_k_ids) @@ -2065,7 +2043,6 @@ def contrastive_search( # prepare model inputs model_kwargs["past_key_values"] = past_key_values - model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) step_counter += 1 if return_dict_in_generate: @@ -3911,9 +3888,9 @@ def ranking_fast( norm_context_hidden = context_hidden / context_hidden.norm(dim=2, keepdim=True) norm_next_hidden = next_hidden / next_hidden.norm(dim=2, keepdim=True) cosine_matrix = torch.matmul(norm_context_hidden, norm_next_hidden.transpose(1, 2)).squeeze(-1) # [B*K, S] - scores, _ = torch.max(cosine_matrix, dim=-1) # [B*K] + degeneration_penalty, _ = torch.max(cosine_matrix, dim=-1) # [B*K] next_top_k_probs = next_top_k_probs.view(-1) # [B*K] - scores = (1.0 - alpha) * next_top_k_probs - alpha * scores - scores = torch.stack(torch.split(scores, beam_width)) # [B, K] - selected_scores, selected_idx = scores.max(dim=-1) # [B] - return selected_scores, selected_idx + contrastive_score = (1.0 - alpha) * next_top_k_probs - alpha * degeneration_penalty + contrastive_score = torch.stack(torch.split(contrastive_score, beam_width)) # [B, K] + selected_scores, selected_idx = contrastive_score.max(dim=-1) # [B] + return torch.log(selected_scores), selected_idx From 183d7cc379171f56c3f17adbba52fb482df4f7b3 Mon Sep 17 00:00:00 2001 From: gmftbyGMFTBY <18811371908@163.com> Date: Tue, 18 Oct 2022 16:26:45 +0800 Subject: [PATCH 13/17] update the intergration test and the docstring --- src/transformers/generation_utils.py | 8 ++++---- tests/generation/test_generation_utils.py | 23 +++++++++++++++++++++++ 2 files changed, 27 insertions(+), 4 deletions(-) diff --git a/src/transformers/generation_utils.py b/src/transformers/generation_utils.py index 52a5cc81610bf..ee6a3c7f98125 100644 --- a/src/transformers/generation_utils.py +++ b/src/transformers/generation_utils.py @@ -1826,18 +1826,18 @@ def contrastive_search( ... MaxLengthCriteria, ... ) - >>> tokenizer = AutoTokenizer.from_pretrained("gpt2-large") - >>> model = AutoModelForCausalLM.from_pretrained("gpt2-large") + >>> tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m") + >>> model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m") >>> # set pad_token_id to eos_token_id because GPT2 does not have a PAD token >>> model.config.pad_token_id = model.config.eos_token_id >>> input_prompt = "DeepMind Company is" >>> input_ids = tokenizer(input_prompt, return_tensors="pt") - >>> stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=256)]) + >>> stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=64)]) >>> outputs = model.contrastive_search( ... **input_ids, penalty_alpha=0.6, top_k=4, stopping_criteria=stopping_criteria ... ) >>> tokenizer.batch_decode(outputs, skip_special_tokens=True) - ["DeepMind Company is a leader in artificial intelligence (AI). We have a long history of working with companies such as Google, Facebook, Amazon, and Microsoft to build products that improve people\'s lives, and today we are excited to announce that DeepMind\'s AlphaGo program has won the game of Go, becoming the first program to defeat a professional Go player.\n\nThe victory is a testament to the power of deep learning, and to the incredible work of our research team, which has been at the forefront of AI research for the past five years. AlphaGo is one of the most advanced Go programs ever created, and its performance is an important step towards the goal of human-level AI.\n\n"This is the culmination of a decade of hard work," said Andy Ng, co-founder and CTO of DeepMind. "We are thrilled to have achieved this milestone and look forward to continuing to develop AI that can be used in a wide range of applications and to help people live better lives."\n\nDeepMind\'s work on Go began in 2010, when it began to train a neural network to play Go using millions of games played by top Go players around the world. Since then, the team has refined the algorithm, adding more and more layers of reinforcement"] + ["DeepMind Company is a company that focuses on the development and commercialization of artificial intelligence (AI). DeepMind’s mission is to help people understand and solve problems that are difficult to solve in the world today.\n\nIn this post, we talk about the benefits of deep learning in business and how it"] ```""" # init values logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() diff --git a/tests/generation/test_generation_utils.py b/tests/generation/test_generation_utils.py index f48cfff83cb85..069a5a8f668bd 100644 --- a/tests/generation/test_generation_utils.py +++ b/tests/generation/test_generation_utils.py @@ -1693,6 +1693,29 @@ def test_diverse_beam_search(self): ], ) + def test_contrastive_search(self): + article = """DeepMind Technologies is a British artificial intelligence subsidiary of Alphabet Inc. and research laboratory founded in 2010. DeepMind was acquired by Google in 2014. The company is based""" + + gpt2_tokenizer = AutoTokenizer.from_pretrained("gpt2-large") + gpt2_model = GPT2LMHeadModel.from_pretrained("gpt2-large").to(torch_device) + input_ids = gpt2_tokenizer(article, return_tensors="pt").input_ids.to(torch_device) + + outputs = gpt2_model.generate( + input_ids, + penalty_alpha=0.6, + top_k=4, + max_length=256 + ) + + generated_text = gpt2_tokenizer.batch_decode(outputs, skip_special_tokens=True) + + self.assertListEqual( + generated_text, + [ + '''DeepMind Technologies is a British artificial intelligence subsidiary of Alphabet Inc. and research laboratory founded in 2010. DeepMind was acquired by Google in 2014. The company is based in London, United Kingdom\n\nGoogle has a lot of data on its users and uses it to improve its products, such as Google Now, which helps users find the information they\'re looking for on the web. But the company is not the only one to collect data on its users. Facebook, for example, has its own facial recognition technology, as well as a database of millions of photos that it uses to personalize its News Feed.\n\nFacebook\'s use of data is a hot topic in the tech industry, with privacy advocates concerned about the company\'s ability to keep users\' information private. In a blog post last year, Facebook CEO Mark Zuckerberg said his company would "do our best to be transparent about our data use and how we use it."\n\n"We have made it clear that we do not sell or share your data with third parties," Zuckerberg wrote. "If you have questions or concerns, please reach out to us at privacy@facebook.com."\n\nGoogle declined to comment on the privacy implications of its use of data, but said in a statement to The Associated Press that''' + ] + ) + def test_max_length_backward_compat_greedy(self): article = """Justin Timberlake and Jessica Biel, welcome to parenthood.""" bart_tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart") From 65a1ebd7e846bfa7b27b2bfb8106dd19171ef71d Mon Sep 17 00:00:00 2001 From: gmftbyGMFTBY <18811371908@163.com> Date: Tue, 18 Oct 2022 16:29:21 +0800 Subject: [PATCH 14/17] run the tests over --- tests/generation/test_generation_utils.py | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/tests/generation/test_generation_utils.py b/tests/generation/test_generation_utils.py index 069a5a8f668bd..7e8a0417b58fd 100644 --- a/tests/generation/test_generation_utils.py +++ b/tests/generation/test_generation_utils.py @@ -1699,21 +1699,16 @@ def test_contrastive_search(self): gpt2_tokenizer = AutoTokenizer.from_pretrained("gpt2-large") gpt2_model = GPT2LMHeadModel.from_pretrained("gpt2-large").to(torch_device) input_ids = gpt2_tokenizer(article, return_tensors="pt").input_ids.to(torch_device) - - outputs = gpt2_model.generate( - input_ids, - penalty_alpha=0.6, - top_k=4, - max_length=256 - ) + + outputs = gpt2_model.generate(input_ids, penalty_alpha=0.6, top_k=4, max_length=256) generated_text = gpt2_tokenizer.batch_decode(outputs, skip_special_tokens=True) self.assertListEqual( generated_text, [ - '''DeepMind Technologies is a British artificial intelligence subsidiary of Alphabet Inc. and research laboratory founded in 2010. DeepMind was acquired by Google in 2014. The company is based in London, United Kingdom\n\nGoogle has a lot of data on its users and uses it to improve its products, such as Google Now, which helps users find the information they\'re looking for on the web. But the company is not the only one to collect data on its users. Facebook, for example, has its own facial recognition technology, as well as a database of millions of photos that it uses to personalize its News Feed.\n\nFacebook\'s use of data is a hot topic in the tech industry, with privacy advocates concerned about the company\'s ability to keep users\' information private. In a blog post last year, Facebook CEO Mark Zuckerberg said his company would "do our best to be transparent about our data use and how we use it."\n\n"We have made it clear that we do not sell or share your data with third parties," Zuckerberg wrote. "If you have questions or concerns, please reach out to us at privacy@facebook.com."\n\nGoogle declined to comment on the privacy implications of its use of data, but said in a statement to The Associated Press that''' - ] + """DeepMind Technologies is a British artificial intelligence subsidiary of Alphabet Inc. and research laboratory founded in 2010. DeepMind was acquired by Google in 2014. The company is based in London, United Kingdom\n\nGoogle has a lot of data on its users and uses it to improve its products, such as Google Now, which helps users find the information they\'re looking for on the web. But the company is not the only one to collect data on its users. Facebook, for example, has its own facial recognition technology, as well as a database of millions of photos that it uses to personalize its News Feed.\n\nFacebook\'s use of data is a hot topic in the tech industry, with privacy advocates concerned about the company\'s ability to keep users\' information private. In a blog post last year, Facebook CEO Mark Zuckerberg said his company would "do our best to be transparent about our data use and how we use it."\n\n"We have made it clear that we do not sell or share your data with third parties," Zuckerberg wrote. "If you have questions or concerns, please reach out to us at privacy@facebook.com."\n\nGoogle declined to comment on the privacy implications of its use of data, but said in a statement to The Associated Press that""" + ], ) def test_max_length_backward_compat_greedy(self): From e11d342235574a3bb13f0a65ff0c9427ffc5a621 Mon Sep 17 00:00:00 2001 From: gmftbyGMFTBY <18811371908@163.com> Date: Tue, 18 Oct 2022 17:35:29 +0800 Subject: [PATCH 15/17] add the slow decorate to the contrastive_search intergrate test --- tests/generation/test_generation_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/generation/test_generation_utils.py b/tests/generation/test_generation_utils.py index 7e8a0417b58fd..53f4c38a23d3a 100644 --- a/tests/generation/test_generation_utils.py +++ b/tests/generation/test_generation_utils.py @@ -1693,6 +1693,7 @@ def test_diverse_beam_search(self): ], ) + @slow def test_contrastive_search(self): article = """DeepMind Technologies is a British artificial intelligence subsidiary of Alphabet Inc. and research laboratory founded in 2010. DeepMind was acquired by Google in 2014. The company is based""" From 2aa768cbc5d12c850f50b4048e9f4be52a78dbcd Mon Sep 17 00:00:00 2001 From: gmftbyGMFTBY <18811371908@163.com> Date: Wed, 19 Oct 2022 13:28:16 +0800 Subject: [PATCH 16/17] add more test --- src/transformers/generation_utils.py | 21 +- tests/generation/test_generation_utils.py | 238 +++++++++++++++++++++- 2 files changed, 246 insertions(+), 13 deletions(-) diff --git a/src/transformers/generation_utils.py b/src/transformers/generation_utils.py index ee6a3c7f98125..f813bffc5a088 100644 --- a/src/transformers/generation_utils.py +++ b/src/transformers/generation_utils.py @@ -1758,7 +1758,6 @@ def contrastive_search( logits_processor: Optional[LogitsProcessorList] = None, logits_warper: Optional[LogitsProcessorList] = None, stopping_criteria: Optional[StoppingCriteriaList] = None, - max_length: Optional[int] = None, pad_token_id: Optional[int] = None, eos_token_id: Optional[int] = None, output_attentions: Optional[bool] = None, @@ -1785,9 +1784,6 @@ def contrastive_search( stopping_criteria (`StoppingCriteriaList`, *optional*): An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`] used to tell if the generation loop should stop. - max_length (`int`, *optional*, defaults to 20): - **DEPRECATED**. Use `logits_processor` or `stopping_criteria` directly to cap the number of generated - tokens. The maximum length of the sequence to be generated. pad_token_id (`int`, *optional*): The id of the *padding* token. eos_token_id (`int`, *optional*): @@ -1843,13 +1839,6 @@ def contrastive_search( logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList() stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() - if max_length is not None: - warnings.warn( - "`max_length` is deprecated in this function, use" - " `stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])` instead.", - UserWarning, - ) - stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length) pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id output_scores = output_scores if output_scores is not None else self.config.output_scores @@ -1890,8 +1879,12 @@ def contrastive_search( if step_counter == 0: # encode the given prefix and prepare model inputs; encoder-decoder model process the prefix and save the `encoder_outputs` output = self(**model_inputs, output_hidden_states=True, output_attentions=True) + # past_key_values is activated for fast decoding + if "past_key_values" not in output: + raise ValueError(f"self.__class__ cannot return `past_key_values` and can therefore **not** be used for contrastive search.") past_key_values = output.past_key_values + # last decoder hidden states will be used to compute the degeneration penalty (cosine similarity with previous tokens) if self.config.is_encoder_decoder: last_hidden_states = output.decoder_hidden_states[-1] @@ -1920,7 +1913,7 @@ def contrastive_search( for item in layer: bsz, num_head, seq_len, esz = item.size() item = ( - item.unsqueeze(1).expand(-1, top_k, -1, -1, -1).reshape(bsz * top_k, num_head, seq_len, esz) + item.unsqueeze(1).expand(-1, top_k, -1, -1, -1).reshape(bsz * top_k, num_head, seq_len, esz).contiguous() ) # [bsz*beam, num_head, seq_len, esz] items.append(item) new_key_values.append(items) @@ -1950,7 +1943,11 @@ def contrastive_search( ) # compute the candidate tokens by the language model and collects their hidden_states output = self(output_hidden_states=True, **next_model_inputs) + + if "past_key_values" not in output: + raise ValueError(f"self.__class__ cannot return `past_key_values` and can therefore **not** be used for contrastive search.") past_key_values = output.past_key_values + logits = output.logits[:, -1, :] # name is different for encoder-decoder and decoder-only models if self.config.is_encoder_decoder: diff --git a/tests/generation/test_generation_utils.py b/tests/generation/test_generation_utils.py index 53f4c38a23d3a..c623c7502b00d 100644 --- a/tests/generation/test_generation_utils.py +++ b/tests/generation/test_generation_utils.py @@ -29,6 +29,9 @@ from transformers import ( AutoModelForSeq2SeqLM, AutoTokenizer, + AutoModelForCausalLM, + OPTForCausalLM, + T5ForConditionalGeneration, BartForConditionalGeneration, BartTokenizer, GPT2LMHeadModel, @@ -1694,7 +1697,118 @@ def test_diverse_beam_search(self): ) @slow - def test_contrastive_search(self): + def test_contrastive_search_bart(self): + article = """ New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County, New York. +A year later, she got married again in Westchester County, but to a different man and without divorcing her first husband. +Only 18 days after that marriage, she got hitched yet again. Then, Barrientos declared "I do" five more times, sometimes only within two weeks of each other. +In 2010, she married once more, this time in the Bronx. In an application for a marriage license, she stated it was her "first and only" marriage. +Barrientos, now 39, is facing two criminal counts of "offering a false instrument for filing in the first degree," referring to her false statements on the +2010 marriage license application, according to court documents. +Prosecutors said the marriages were part of an immigration scam. +On Friday, she pleaded not guilty at State Supreme Court in the Bronx, according to her attorney, Christopher Wright, who declined to comment further. +After leaving court, Barrientos was arrested and charged with theft of service and criminal trespass for allegedly sneaking into the New York subway through an emergency exit, said Detective +Annette Markowski, a police spokeswoman. In total, Barrientos has been married 10 times, with nine of her marriages occurring between 1999 and 2002. +All occurred either in Westchester County, Long Island, New Jersey or the Bronx. She is believed to still be married to four men, and at one time, she was married to eight men at once, prosecutors say. +Prosecutors said the immigration scam involved some of her husbands, who filed for permanent residence status shortly after the marriages. +Any divorces happened only after such filings were approved. It was unclear whether any of the men will be prosecuted. +The case was referred to the Bronx District Attorney\'s Office by Immigration and Customs Enforcement and the Department of Homeland Security\'s +Investigation Division. Seven of the men are from so-called "red-flagged" countries, including Egypt, Turkey, Georgia, Pakistan and Mali. +Her eighth husband, Rashid Rajput, was deported in 2006 to his native Pakistan after an investigation by the Joint Terrorism Task Force. +If convicted, Barrientos faces up to four years in prison. Her next court appearance is scheduled for May 18. +""" + bart_tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-cnn") + bart_model = BartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn").to(torch_device) + input_ids = bart_tokenizer(article, add_special_tokens=False, truncation=True, max_length=512, return_tensors="pt").input_ids.to(torch_device) + + outputs = bart_model.generate(input_ids, penalty_alpha=0.5, top_k=5, max_length=64) + generated_text = bart_tokenizer.batch_decode(outputs, skip_special_tokens=True) + + self.assertListEqual( + generated_text, + [ + '''Liana Barrientos, 39, pleaded not guilty to two counts of "offering a false instrument" Prosecutors say the marriages were part of an immigration scam. In total, Barriento has been married 10 times, with nine of her marriages occurring between 1999 and 2002.''' + ], + ) + + @slow + def test_contrastive_search_t5(self): + article = """ New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County, New York. +A year later, she got married again in Westchester County, but to a different man and without divorcing her first husband. +Only 18 days after that marriage, she got hitched yet again. Then, Barrientos declared "I do" five more times, sometimes only within two weeks of each other. +In 2010, she married once more, this time in the Bronx. In an application for a marriage license, she stated it was her "first and only" marriage. +Barrientos, now 39, is facing two criminal counts of "offering a false instrument for filing in the first degree," referring to her false statements on the +2010 marriage license application, according to court documents. +Prosecutors said the marriages were part of an immigration scam. +On Friday, she pleaded not guilty at State Supreme Court in the Bronx, according to her attorney, Christopher Wright, who declined to comment further. +After leaving court, Barrientos was arrested and charged with theft of service and criminal trespass for allegedly sneaking into the New York subway through an emergency exit, said Detective +Annette Markowski, a police spokeswoman. In total, Barrientos has been married 10 times, with nine of her marriages occurring between 1999 and 2002. +All occurred either in Westchester County, Long Island, New Jersey or the Bronx. She is believed to still be married to four men, and at one time, she was married to eight men at once, prosecutors say. +Prosecutors said the immigration scam involved some of her husbands, who filed for permanent residence status shortly after the marriages. +Any divorces happened only after such filings were approved. It was unclear whether any of the men will be prosecuted. +The case was referred to the Bronx District Attorney\'s Office by Immigration and Customs Enforcement and the Department of Homeland Security\'s +Investigation Division. Seven of the men are from so-called "red-flagged" countries, including Egypt, Turkey, Georgia, Pakistan and Mali. +Her eighth husband, Rashid Rajput, was deported in 2006 to his native Pakistan after an investigation by the Joint Terrorism Task Force. +If convicted, Barrientos faces up to four years in prison. Her next court appearance is scheduled for May 18. +""" + article = "summarize: " + article.strip() + t5_tokenizer = AutoTokenizer.from_pretrained("flax-community/t5-base-cnn-dm") + t5_model = T5ForConditionalGeneration.from_pretrained("flax-community/t5-base-cnn-dm").to(torch_device) + input_ids = t5_tokenizer(article, add_special_tokens=False, truncation=True, max_length=512, return_tensors="pt").input_ids.to(torch_device) + + outputs = t5_model.generate(input_ids, penalty_alpha=0.5, top_k=5, max_length=64) + generated_text = t5_tokenizer.batch_decode(outputs, skip_special_tokens=True) + + self.assertListEqual( + generated_text, + [ + '''Liana Barrientos has been married 10 times, nine of them in the Bronx. Her husbands filed for permanent residence after the marriages, prosecutors say.''' + ], + ) + + @slow + def test_contrastive_search_opt(self): + article = r"""A chat between a curious human and the Statue of Liberty. + +Human: What is your name? +Statue: I am the Statue of Liberty. +Human: Where do you live? +Statue: New York City. +Human: How long have you lived there?""" + + opt_tokenizer = GPT2Tokenizer.from_pretrained("facebook/opt-6.7b") + opt_model = OPTForCausalLM.from_pretrained("facebook/opt-6.7b").to(torch_device) + input_ids = opt_tokenizer(article, return_tensors="pt").input_ids.to(torch_device) + + outputs = opt_model.generate(input_ids, penalty_alpha=0.6, top_k=5, max_length=256) + generated_text = opt_tokenizer.batch_decode(outputs, skip_special_tokens=True) + + self.assertListEqual( + generated_text, + [ + '''A chat between a curious human and the Statue of Liberty.\n\nHuman: What is your name?\nStatue: I am the Statue of Liberty.\nHuman: Where do you live?\nStatue: New York City.\nHuman: How long have you lived there?\nStatue: Since 1884.\nHuman: Why did you come to America?\nStatue: I was given to the United States by France as a gift for helping the French during the Franco-Prussian War.\nHuman: What do you think of America?\nStatue: I love it. It is the greatest country in the world.\nHuman: What’s the weather like in New York?\nStatue: It is cold.\nHuman: Is it safe to walk around at night?\nStatue: Yes. There are policemen everywhere.\nHuman: Do you have any children?\nStatue: Not yet. My pedestal is empty.\nHuman: What would you like to say to people who want to immigrate to America?\nStatue: Come on over. You will be happy here. We have everything you need.\nSource: http://www.statueofliberty.org/index.cf''' + ], + ) + + @slow + def test_contrastive_search_gptj(self): + article = """DeepMind Technologies is a British artificial intelligence subsidiary of Alphabet Inc. and research laboratory founded in 2010. DeepMind was acquired by Google in 2014. The company is based""" + + opt_tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B") + opt_model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-j-6B").to(torch_device) + input_ids = opt_tokenizer(article, return_tensors="pt").input_ids.to(torch_device) + + outputs = opt_model.generate(input_ids, penalty_alpha=0.6, top_k=4, max_length=256) + generated_text = opt_tokenizer.batch_decode(outputs, skip_special_tokens=True) + + self.assertListEqual( + generated_text, + [ + '''DeepMind Technologies is a British artificial intelligence subsidiary of Alphabet Inc. and research laboratory founded in 2010. DeepMind was acquired by Google in 2014. The company is based in London, United Kingdom with offices in Mountain View, San Francisco, New York City, Paris, Tokyo, Seoul, Beijing, Singapore, Tel Aviv, Dublin, Sydney, and Melbourne.[1]\n\nContents\n\nIn 2010, Google\'s parent company, Alphabet, announced a $500 million investment in DeepMind, with the aim of creating a company that would apply deep learning to problems in healthcare, energy, transportation, and other areas.[2]\n\nOn April 23, 2014, Google announced that it had acquired DeepMind for $400 million in cash and stock.[3] The acquisition was seen as a move to strengthen Google\'s position in the fast-growing field of artificial intelligence (AI), which it had invested in since 2010.[4] Google CEO Larry Page said that the company was "excited to have DeepMind on board" and that "this is a step towards our goal of building AI that works for everyone, not just a few".[5]\n\nDeepMind\'s co-founders, Demis Hassabis and Mustafa Suleyman, were named CEO and C''' + ], + ) + + @slow + def test_contrastive_search_gpt2(self): article = """DeepMind Technologies is a British artificial intelligence subsidiary of Alphabet Inc. and research laboratory founded in 2010. DeepMind was acquired by Google in 2014. The company is based""" gpt2_tokenizer = AutoTokenizer.from_pretrained("gpt2-large") @@ -2069,6 +2183,128 @@ def test_max_new_tokens_encoder_decoder(self): with self.assertRaises(ValueError): bart_model.generate(decoder_input_ids=input_ids, max_new_tokens=10, max_length=20) + def test_max_new_tokens_decoder_only_contrastive_search_t5(self): + article = """Justin Timberlake and Jessica Biel, welcome to parenthood.""" + t5_tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5") + t5_model = T5ForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-t5").to( + torch_device + ) + input_ids = t5_tokenizer(article, return_tensors="pt").input_ids.to(torch_device) + + self.assertEqual(list(input_ids.shape), [1, 56]) + + max_new_tokens = 3 + t5_model.config.max_length = 20 + t5_model.config.eos_token_id = None + + # Encoder decoder call + outputs = t5_model.generate(input_ids, max_new_tokens=max_new_tokens, penalty_alpha=0.6, top_k=4) + # 1 BOS + 3 new tokens + self.assertEqual(list(outputs.shape), [1, 4]) + + # Decoder only call + outputs = t5_model.generate(decoder_input_ids=input_ids, max_new_tokens=max_new_tokens, penalty_alpha=0.6, top_k=4) + # 56 + 3 new tokens + self.assertEqual(list(outputs.shape), [1, 59]) + + # Encoder decoder call > 20 + outputs = t5_model.generate(max_new_tokens=max_new_tokens + 20, penalty_alpha=0.6, top_k=4) + + # 1 BOS + 20 + 3 new tokens + self.assertEqual(list(outputs.shape), [1, 24]) + + # max_new_tokens and max_length serve the same purpose and must not be used together. + with self.assertRaises(ValueError): + t5_model.generate(decoder_input_ids=input_ids, max_new_tokens=10, max_length=20, penalty_alpha=0.6, top_k=4) + + def test_max_new_tokens_decoder_only_contrastive_search_bart(self): + article = """Justin Timberlake and Jessica Biel, welcome to parenthood.""" + bart_tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart") + bart_model = BartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart").to( + torch_device + ) + input_ids = bart_tokenizer(article, return_tensors="pt").input_ids.to(torch_device) + + self.assertEqual(list(input_ids.shape), [1, 29]) + + max_new_tokens = 3 + bart_model.config.max_length = 20 + bart_model.config.eos_token_id = None + + # Encoder decoder call + outputs = bart_model.generate(input_ids, max_new_tokens=max_new_tokens, penalty_alpha=0.6, top_k=4) + # 1 BOS + 3 new tokens + self.assertEqual(list(outputs.shape), [1, 4]) + + # Decoder only call + outputs = bart_model.generate(decoder_input_ids=input_ids, max_new_tokens=max_new_tokens, penalty_alpha=0.6, top_k=4) + # 29 + 3 new tokens + self.assertEqual(list(outputs.shape), [1, 32]) + + # Encoder decoder call > 20 + outputs = bart_model.generate(max_new_tokens=max_new_tokens + 20, penalty_alpha=0.6, top_k=4) + + # 1 BOS + 20 + 3 new tokens + self.assertEqual(list(outputs.shape), [1, 24]) + + # max_new_tokens and max_length serve the same purpose and must not be used together. + with self.assertRaises(ValueError): + bart_model.generate(decoder_input_ids=input_ids, max_new_tokens=10, max_length=20, penalty_alpha=0.6, top_k=4) + + def test_max_new_tokens_decoder_only_contrastive_search_gptj(self): + article = """Justin Timberlake.""" + gptj_tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gptj") + gptj_model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gptj").to(torch_device) + input_ids = gptj_tokenizer(article, return_tensors="pt").input_ids.to(torch_device) + + self.assertEqual(list(input_ids.shape), [1, 9]) + + max_new_tokens = 3 + gptj_model.config.max_length = 20 + + # call < 20 + outputs = gptj_model.generate(input_ids, max_new_tokens=max_new_tokens, penalty_alpha=0.6, top_k=4) + + # 9 input_ids + 3 new tokens + self.assertEqual(list(outputs.shape), [1, 12]) + + # call > 20 + outputs = gptj_model.generate(max_new_tokens=max_new_tokens + 20, penalty_alpha=0.6, top_k=4) + + # 1 BOS token + 23 new tokens + self.assertEqual(list(outputs.shape), [1, 24]) + + # max_new_tokens and max_length serve the same purpose and must not be used together. + with self.assertRaises(ValueError): + gptj_model.generate(input_ids=input_ids, max_new_tokens=10, max_length=20, penalty_alpha=0.6, top_k=4) + + def test_max_new_tokens_decoder_only_contrastive_search_gpt2(self): + article = """Justin Timberlake.""" + gpt2_tokenizer = GPT2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2") + gpt2_model = GPT2LMHeadModel.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device) + input_ids = gpt2_tokenizer(article, return_tensors="pt").input_ids.to(torch_device) + + self.assertEqual(list(input_ids.shape), [1, 9]) + + max_new_tokens = 3 + gpt2_model.config.max_length = 20 + + # call < 20 + outputs = gpt2_model.generate(input_ids, max_new_tokens=max_new_tokens, penalty_alpha=0.6, top_k=4) + + # 9 input_ids + 3 new tokens + self.assertEqual(list(outputs.shape), [1, 12]) + + # call > 20 + outputs = gpt2_model.generate(max_new_tokens=max_new_tokens + 20, penalty_alpha=0.6, top_k=4) + + # 1 BOS token + 23 new tokens + self.assertEqual(list(outputs.shape), [1, 24]) + + # max_new_tokens and max_length serve the same purpose and must not be used together. + with self.assertRaises(ValueError): + gpt2_model.generate(input_ids=input_ids, max_new_tokens=10, max_length=20, penalty_alpha=0.6, top_k=4) + def test_max_new_tokens_decoder_only(self): article = """Justin Timberlake.""" gpt2_tokenizer = GPT2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2") From ced9f70a390b96dc6f8b4db3eb077c2732ea8605 Mon Sep 17 00:00:00 2001 From: gmftbyGMFTBY <18811371908@163.com> Date: Wed, 19 Oct 2022 13:39:50 +0800 Subject: [PATCH 17/17] do the style, quality, consistency checks --- src/transformers/generation_utils.py | 17 +++++++-- tests/generation/test_generation_utils.py | 46 ++++++++++++++--------- 2 files changed, 41 insertions(+), 22 deletions(-) diff --git a/src/transformers/generation_utils.py b/src/transformers/generation_utils.py index f813bffc5a088..06ead6f771953 100644 --- a/src/transformers/generation_utils.py +++ b/src/transformers/generation_utils.py @@ -1879,10 +1879,13 @@ def contrastive_search( if step_counter == 0: # encode the given prefix and prepare model inputs; encoder-decoder model process the prefix and save the `encoder_outputs` output = self(**model_inputs, output_hidden_states=True, output_attentions=True) - + # past_key_values is activated for fast decoding if "past_key_values" not in output: - raise ValueError(f"self.__class__ cannot return `past_key_values` and can therefore **not** be used for contrastive search.") + raise ValueError( + "self.__class__ cannot return `past_key_values` and can therefore **not** be used for" + " contrastive search." + ) past_key_values = output.past_key_values # last decoder hidden states will be used to compute the degeneration penalty (cosine similarity with previous tokens) @@ -1913,7 +1916,10 @@ def contrastive_search( for item in layer: bsz, num_head, seq_len, esz = item.size() item = ( - item.unsqueeze(1).expand(-1, top_k, -1, -1, -1).reshape(bsz * top_k, num_head, seq_len, esz).contiguous() + item.unsqueeze(1) + .expand(-1, top_k, -1, -1, -1) + .reshape(bsz * top_k, num_head, seq_len, esz) + .contiguous() ) # [bsz*beam, num_head, seq_len, esz] items.append(item) new_key_values.append(items) @@ -1945,7 +1951,10 @@ def contrastive_search( output = self(output_hidden_states=True, **next_model_inputs) if "past_key_values" not in output: - raise ValueError(f"self.__class__ cannot return `past_key_values` and can therefore **not** be used for contrastive search.") + raise ValueError( + "self.__class__ cannot return `past_key_values` and can therefore **not** be used for contrastive" + " search." + ) past_key_values = output.past_key_values logits = output.logits[:, -1, :] diff --git a/tests/generation/test_generation_utils.py b/tests/generation/test_generation_utils.py index c623c7502b00d..d2347bc0aaf45 100644 --- a/tests/generation/test_generation_utils.py +++ b/tests/generation/test_generation_utils.py @@ -27,18 +27,18 @@ import torch from transformers import ( + AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer, - AutoModelForCausalLM, - OPTForCausalLM, - T5ForConditionalGeneration, BartForConditionalGeneration, BartTokenizer, GPT2LMHeadModel, GPT2Tokenizer, ImageGPTForCausalImageModeling, + OPTForCausalLM, Speech2TextForConditionalGeneration, SpeechEncoderDecoderModel, + T5ForConditionalGeneration, VisionEncoderDecoderModel, pipeline, top_k_top_p_filtering, @@ -1698,7 +1698,7 @@ def test_diverse_beam_search(self): @slow def test_contrastive_search_bart(self): - article = """ New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County, New York. + article = """ New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County, New York. A year later, she got married again in Westchester County, but to a different man and without divorcing her first husband. Only 18 days after that marriage, she got hitched yet again. Then, Barrientos declared "I do" five more times, sometimes only within two weeks of each other. In 2010, she married once more, this time in the Bronx. In an application for a marriage license, she stated it was her "first and only" marriage. @@ -1718,7 +1718,9 @@ def test_contrastive_search_bart(self): """ bart_tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-cnn") bart_model = BartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn").to(torch_device) - input_ids = bart_tokenizer(article, add_special_tokens=False, truncation=True, max_length=512, return_tensors="pt").input_ids.to(torch_device) + input_ids = bart_tokenizer( + article, add_special_tokens=False, truncation=True, max_length=512, return_tensors="pt" + ).input_ids.to(torch_device) outputs = bart_model.generate(input_ids, penalty_alpha=0.5, top_k=5, max_length=64) generated_text = bart_tokenizer.batch_decode(outputs, skip_special_tokens=True) @@ -1726,13 +1728,13 @@ def test_contrastive_search_bart(self): self.assertListEqual( generated_text, [ - '''Liana Barrientos, 39, pleaded not guilty to two counts of "offering a false instrument" Prosecutors say the marriages were part of an immigration scam. In total, Barriento has been married 10 times, with nine of her marriages occurring between 1999 and 2002.''' + """Liana Barrientos, 39, pleaded not guilty to two counts of "offering a false instrument" Prosecutors say the marriages were part of an immigration scam. In total, Barriento has been married 10 times, with nine of her marriages occurring between 1999 and 2002.""" ], ) @slow def test_contrastive_search_t5(self): - article = """ New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County, New York. + article = """ New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County, New York. A year later, she got married again in Westchester County, but to a different man and without divorcing her first husband. Only 18 days after that marriage, she got hitched yet again. Then, Barrientos declared "I do" five more times, sometimes only within two weeks of each other. In 2010, she married once more, this time in the Bronx. In an application for a marriage license, she stated it was her "first and only" marriage. @@ -1753,7 +1755,9 @@ def test_contrastive_search_t5(self): article = "summarize: " + article.strip() t5_tokenizer = AutoTokenizer.from_pretrained("flax-community/t5-base-cnn-dm") t5_model = T5ForConditionalGeneration.from_pretrained("flax-community/t5-base-cnn-dm").to(torch_device) - input_ids = t5_tokenizer(article, add_special_tokens=False, truncation=True, max_length=512, return_tensors="pt").input_ids.to(torch_device) + input_ids = t5_tokenizer( + article, add_special_tokens=False, truncation=True, max_length=512, return_tensors="pt" + ).input_ids.to(torch_device) outputs = t5_model.generate(input_ids, penalty_alpha=0.5, top_k=5, max_length=64) generated_text = t5_tokenizer.batch_decode(outputs, skip_special_tokens=True) @@ -1761,7 +1765,7 @@ def test_contrastive_search_t5(self): self.assertListEqual( generated_text, [ - '''Liana Barrientos has been married 10 times, nine of them in the Bronx. Her husbands filed for permanent residence after the marriages, prosecutors say.''' + """Liana Barrientos has been married 10 times, nine of them in the Bronx. Her husbands filed for permanent residence after the marriages, prosecutors say.""" ], ) @@ -1785,7 +1789,7 @@ def test_contrastive_search_opt(self): self.assertListEqual( generated_text, [ - '''A chat between a curious human and the Statue of Liberty.\n\nHuman: What is your name?\nStatue: I am the Statue of Liberty.\nHuman: Where do you live?\nStatue: New York City.\nHuman: How long have you lived there?\nStatue: Since 1884.\nHuman: Why did you come to America?\nStatue: I was given to the United States by France as a gift for helping the French during the Franco-Prussian War.\nHuman: What do you think of America?\nStatue: I love it. It is the greatest country in the world.\nHuman: What’s the weather like in New York?\nStatue: It is cold.\nHuman: Is it safe to walk around at night?\nStatue: Yes. There are policemen everywhere.\nHuman: Do you have any children?\nStatue: Not yet. My pedestal is empty.\nHuman: What would you like to say to people who want to immigrate to America?\nStatue: Come on over. You will be happy here. We have everything you need.\nSource: http://www.statueofliberty.org/index.cf''' + """A chat between a curious human and the Statue of Liberty.\n\nHuman: What is your name?\nStatue: I am the Statue of Liberty.\nHuman: Where do you live?\nStatue: New York City.\nHuman: How long have you lived there?\nStatue: Since 1884.\nHuman: Why did you come to America?\nStatue: I was given to the United States by France as a gift for helping the French during the Franco-Prussian War.\nHuman: What do you think of America?\nStatue: I love it. It is the greatest country in the world.\nHuman: What’s the weather like in New York?\nStatue: It is cold.\nHuman: Is it safe to walk around at night?\nStatue: Yes. There are policemen everywhere.\nHuman: Do you have any children?\nStatue: Not yet. My pedestal is empty.\nHuman: What would you like to say to people who want to immigrate to America?\nStatue: Come on over. You will be happy here. We have everything you need.\nSource: http://www.statueofliberty.org/index.cf""" ], ) @@ -1803,7 +1807,7 @@ def test_contrastive_search_gptj(self): self.assertListEqual( generated_text, [ - '''DeepMind Technologies is a British artificial intelligence subsidiary of Alphabet Inc. and research laboratory founded in 2010. DeepMind was acquired by Google in 2014. The company is based in London, United Kingdom with offices in Mountain View, San Francisco, New York City, Paris, Tokyo, Seoul, Beijing, Singapore, Tel Aviv, Dublin, Sydney, and Melbourne.[1]\n\nContents\n\nIn 2010, Google\'s parent company, Alphabet, announced a $500 million investment in DeepMind, with the aim of creating a company that would apply deep learning to problems in healthcare, energy, transportation, and other areas.[2]\n\nOn April 23, 2014, Google announced that it had acquired DeepMind for $400 million in cash and stock.[3] The acquisition was seen as a move to strengthen Google\'s position in the fast-growing field of artificial intelligence (AI), which it had invested in since 2010.[4] Google CEO Larry Page said that the company was "excited to have DeepMind on board" and that "this is a step towards our goal of building AI that works for everyone, not just a few".[5]\n\nDeepMind\'s co-founders, Demis Hassabis and Mustafa Suleyman, were named CEO and C''' + """DeepMind Technologies is a British artificial intelligence subsidiary of Alphabet Inc. and research laboratory founded in 2010. DeepMind was acquired by Google in 2014. The company is based in London, United Kingdom with offices in Mountain View, San Francisco, New York City, Paris, Tokyo, Seoul, Beijing, Singapore, Tel Aviv, Dublin, Sydney, and Melbourne.[1]\n\nContents\n\nIn 2010, Google\'s parent company, Alphabet, announced a $500 million investment in DeepMind, with the aim of creating a company that would apply deep learning to problems in healthcare, energy, transportation, and other areas.[2]\n\nOn April 23, 2014, Google announced that it had acquired DeepMind for $400 million in cash and stock.[3] The acquisition was seen as a move to strengthen Google\'s position in the fast-growing field of artificial intelligence (AI), which it had invested in since 2010.[4] Google CEO Larry Page said that the company was "excited to have DeepMind on board" and that "this is a step towards our goal of building AI that works for everyone, not just a few".[5]\n\nDeepMind\'s co-founders, Demis Hassabis and Mustafa Suleyman, were named CEO and C""" ], ) @@ -2186,9 +2190,7 @@ def test_max_new_tokens_encoder_decoder(self): def test_max_new_tokens_decoder_only_contrastive_search_t5(self): article = """Justin Timberlake and Jessica Biel, welcome to parenthood.""" t5_tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5") - t5_model = T5ForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-t5").to( - torch_device - ) + t5_model = T5ForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-t5").to(torch_device) input_ids = t5_tokenizer(article, return_tensors="pt").input_ids.to(torch_device) self.assertEqual(list(input_ids.shape), [1, 56]) @@ -2203,7 +2205,9 @@ def test_max_new_tokens_decoder_only_contrastive_search_t5(self): self.assertEqual(list(outputs.shape), [1, 4]) # Decoder only call - outputs = t5_model.generate(decoder_input_ids=input_ids, max_new_tokens=max_new_tokens, penalty_alpha=0.6, top_k=4) + outputs = t5_model.generate( + decoder_input_ids=input_ids, max_new_tokens=max_new_tokens, penalty_alpha=0.6, top_k=4 + ) # 56 + 3 new tokens self.assertEqual(list(outputs.shape), [1, 59]) @@ -2215,7 +2219,9 @@ def test_max_new_tokens_decoder_only_contrastive_search_t5(self): # max_new_tokens and max_length serve the same purpose and must not be used together. with self.assertRaises(ValueError): - t5_model.generate(decoder_input_ids=input_ids, max_new_tokens=10, max_length=20, penalty_alpha=0.6, top_k=4) + t5_model.generate( + decoder_input_ids=input_ids, max_new_tokens=10, max_length=20, penalty_alpha=0.6, top_k=4 + ) def test_max_new_tokens_decoder_only_contrastive_search_bart(self): article = """Justin Timberlake and Jessica Biel, welcome to parenthood.""" @@ -2237,7 +2243,9 @@ def test_max_new_tokens_decoder_only_contrastive_search_bart(self): self.assertEqual(list(outputs.shape), [1, 4]) # Decoder only call - outputs = bart_model.generate(decoder_input_ids=input_ids, max_new_tokens=max_new_tokens, penalty_alpha=0.6, top_k=4) + outputs = bart_model.generate( + decoder_input_ids=input_ids, max_new_tokens=max_new_tokens, penalty_alpha=0.6, top_k=4 + ) # 29 + 3 new tokens self.assertEqual(list(outputs.shape), [1, 32]) @@ -2249,7 +2257,9 @@ def test_max_new_tokens_decoder_only_contrastive_search_bart(self): # max_new_tokens and max_length serve the same purpose and must not be used together. with self.assertRaises(ValueError): - bart_model.generate(decoder_input_ids=input_ids, max_new_tokens=10, max_length=20, penalty_alpha=0.6, top_k=4) + bart_model.generate( + decoder_input_ids=input_ids, max_new_tokens=10, max_length=20, penalty_alpha=0.6, top_k=4 + ) def test_max_new_tokens_decoder_only_contrastive_search_gptj(self): article = """Justin Timberlake."""