diff --git a/docs/source/en/internal/generation_utils.mdx b/docs/source/en/internal/generation_utils.mdx index bdb6c7c59ce32..31db57740eca5 100644 --- a/docs/source/en/internal/generation_utils.mdx +++ b/docs/source/en/internal/generation_utils.mdx @@ -14,6 +14,7 @@ specific language governing permissions and limitations under the License. This page lists all the utility functions used by [`~generation_utils.GenerationMixin.generate`], [`~generation_utils.GenerationMixin.greedy_search`], +[`~generation_utils.GenerationMixin.contrastive_search`], [`~generation_utils.GenerationMixin.sample`], [`~generation_utils.GenerationMixin.beam_search`], [`~generation_utils.GenerationMixin.beam_sample`], diff --git a/docs/source/en/main_classes/text_generation.mdx b/docs/source/en/main_classes/text_generation.mdx index 94deeeae89411..2fc7950cdbe1c 100644 --- a/docs/source/en/main_classes/text_generation.mdx +++ b/docs/source/en/main_classes/text_generation.mdx @@ -26,6 +26,7 @@ Each framework has a generate method for auto-regressive text generation impleme - sample - beam_search - beam_sample + - contrastive_search - group_beam_search - constrained_beam_search diff --git a/examples/pytorch/text-generation/run_generation_contrastive_search.py b/examples/pytorch/text-generation/run_generation_contrastive_search.py new file mode 100755 index 0000000000000..117f063a6dd9a --- /dev/null +++ b/examples/pytorch/text-generation/run_generation_contrastive_search.py @@ -0,0 +1,138 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2022 University of Cambridge, Tencent AI Lab, DeepMind and The University of Hong Kong Authors and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" The examples of running contrastive search on the auto-APIs; + +Running this example: +python run_generation_contrastive_search.py --model_name_or_path=gpt2-large --penalty_alpha=0.6 --k=4 --length=256 +""" + + +import argparse +import logging + +import numpy as np +import torch + +from transformers import AutoModelForCausalLM, AutoTokenizer + + +logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, +) +logger = logging.getLogger(__name__) + + +def set_seed(args): + np.random.seed(args.seed) + torch.manual_seed(args.seed) + if args.n_gpu > 0: + torch.cuda.manual_seed_all(args.seed) + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--model_name_or_path", + default=None, + type=str, + required=True, + ) + parser.add_argument("--prompt", type=str, default="") + parser.add_argument("--length", type=int, default=20) + parser.add_argument("--stop_token", type=str, default=None, help="Token at which text generation is stopped") + parser.add_argument( + "--temperature", + type=float, + default=1.0, + help="temperature of 1.0 has no effect, lower tend toward greedy sampling", + ) + parser.add_argument( + "--repetition_penalty", type=float, default=1.0, help="primarily useful for CTRL model; in that case, use 1.2" + ) + parser.add_argument("--k", type=int, default=0) + parser.add_argument("--penalty_alpha", type=float, default=0.0) + parser.add_argument("--p", type=float, default=0.9) + + parser.add_argument("--prefix", type=str, default="", help="Text added prior to input.") + parser.add_argument("--padding_text", type=str, default="", help="Deprecated, the use of `--prefix` is preferred.") + parser.add_argument("--xlm_language", type=str, default="", help="Optional language when used with the XLM model.") + + parser.add_argument("--seed", type=int, default=42, help="random seed for initialization") + parser.add_argument("--no_cuda", action="store_true", help="Avoid using CUDA when available") + parser.add_argument( + "--fp16", + action="store_true", + help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit", + ) + args = parser.parse_args() + + args.device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu") + args.n_gpu = 0 if args.no_cuda else torch.cuda.device_count() + + logger.warning(f"device: {args.device}, n_gpu: {args.n_gpu}, 16-bits training: {args.fp16}") + + set_seed(args) + + # Initialize the model and tokenizer + tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path) + model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path) + + # tokenizer = GPT2Tokenizer.from_pretrained(args.model_name_or_path) + # model = OPTForCausalLM.from_pretrained(args.model_name_or_path) + model.to(args.device) + + if args.fp16: + model.half() + + logger.info(args) + prompt_text = args.prompt if args.prompt else input("Model prompt >>> ") + + inputs = tokenizer(prompt_text, return_tensors="pt", add_special_tokens=False) + inputs = {key: value.to(args.device) for key, value in inputs.items()} + + output_sequences = model.generate( + **inputs, + max_length=args.length + len(inputs["input_ids"][0]), + penalty_alpha=args.penalty_alpha, + top_k=args.k, + ) + + generated_sequences = [] + for generated_sequence_idx, generated_sequence in enumerate(output_sequences): + print(f"=== GENERATED SEQUENCE {generated_sequence_idx + 1} ===") + generated_sequence = generated_sequence.tolist() + + # Decode text + text = tokenizer.decode(generated_sequence, clean_up_tokenization_spaces=True, add_special_tokens=False) + + # Remove all text after the stop token + text = text[: text.find(args.stop_token) if args.stop_token else None] + + # Add the prompt at the beginning of the sequence. Remove the excess text that was used for pre-processing + total_sequence = ( + prompt_text + text[len(tokenizer.decode(inputs["input_ids"][0], clean_up_tokenization_spaces=True)) :] + ) + + generated_sequences.append(total_sequence) + print(total_sequence) + + return generated_sequences + + +if __name__ == "__main__": + main() diff --git a/src/transformers/generation_utils.py b/src/transformers/generation_utils.py index ad533a06f1db5..06ead6f771953 100644 --- a/src/transformers/generation_utils.py +++ b/src/transformers/generation_utils.py @@ -54,6 +54,7 @@ StoppingCriteriaList, validate_stopping_criteria, ) +from .modeling_outputs import CausalLMOutputWithCrossAttentions, Seq2SeqLMOutput from .models.auto import ( MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING, MODEL_FOR_CAUSAL_LM_MAPPING, @@ -96,6 +97,54 @@ class GreedySearchDecoderOnlyOutput(ModelOutput): hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None +@dataclass +class ContrastiveSearchEncoderDecoderOutput(ModelOutput): + """ + Args: + Base class for outputs of decoder-only generation models using contrastive search. + sequences (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter + if all batches finished early due to the `eos_token_id`. + scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True` is passed or when + `config.output_scores=True`): + Processed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax) + at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for + each generated token), with each tensor of shape `(batch_size, config.vocab_size)`. + decoder_hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` + is passed or when `config.output_hidden_states=True`): + Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of + `torch.FloatTensor` of shape `(batch_size, generated_length, hidden_size)`. + """ + + sequences: torch.LongTensor = None + scores: Optional[Tuple[torch.FloatTensor]] = None + decoder_hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + + +@dataclass +class ContrastiveSearchDecoderOnlyOutput(ModelOutput): + """ + Args: + Base class for outputs of decoder-only generation models using contrastive search. + sequences (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter + if all batches finished early due to the `eos_token_id`. + scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True` is passed or when + `config.output_scores=True`): + Processed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax) + at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for + each generated token), with each tensor of shape `(batch_size, config.vocab_size)`. + hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` is + passed or when `config.output_hidden_states=True`): + Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of + `torch.FloatTensor` of shape `(batch_size, generated_length, hidden_size)`. + """ + + sequences: torch.LongTensor = None + scores: Optional[Tuple[torch.FloatTensor]] = None + hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + + @dataclass class GreedySearchEncoderDecoderOutput(ModelOutput): """ @@ -393,6 +442,8 @@ class GenerationMixin: The class exposes [`~generation_utils.GenerationMixin.generate`], which can be used for: - *greedy decoding* by calling [`~generation_utils.GenerationMixin.greedy_search`] if `num_beams=1` and `do_sample=False`. + - *contrastive search* by calling [`~generation_utils.GenerationMixin.contrastive_search`] if `penalty_alpha>0` + and `top_k>1` - *multinomial sampling* by calling [`~generation_utils.GenerationMixin.sample`] if `num_beams=1` and `do_sample=True`. - *beam-search decoding* by calling [`~generation_utils.GenerationMixin.beam_search`] if `num_beams>1` and @@ -921,6 +972,7 @@ def generate( early_stopping: Optional[bool] = None, num_beams: Optional[int] = None, temperature: Optional[float] = None, + penalty_alpha: Optional[float] = None, top_k: Optional[int] = None, top_p: Optional[float] = None, typical_p: Optional[float] = None, @@ -966,6 +1018,8 @@ def generate( - *greedy decoding* by calling [`~generation_utils.GenerationMixin.greedy_search`] if `num_beams=1` and `do_sample=False`. + - *contrastive search* by calling [`~generation_utils.GenerationMixin.contrastive_search`] if + `penalty_alpha>0.` and `top_k>1` - *multinomial sampling* by calling [`~generation_utils.GenerationMixin.sample`] if `num_beams=1` and `do_sample=True`. - *beam-search decoding* by calling [`~generation_utils.GenerationMixin.beam_search`] if `num_beams>1` and @@ -1011,6 +1065,8 @@ def generate( Number of beams for beam search. 1 means no beam search. temperature (`float`, *optional*, defaults to `model.config.temperature` or 1.0 if the config does not set any value): The value used to module the next token probabilities. + penalty_alpha (`float`, *optional*, defaults to `model.config.penalty_alpha` or None if the config does not set any value): + The values balance the model confidence and the degeneration penalty in contrastive search decoding. top_k (`int`, *optional*, defaults to `model.config.top_k` or 50 if the config does not set any value): The number of highest probability vocabulary tokens to keep for top-k-filtering. top_p (`float`, *optional*, defaults to `model.config.top_p` or 1.0 if the config does not set any value): @@ -1329,19 +1385,45 @@ def generate( # 6. determine generation mode is_constraint_gen_mode = constraints is not None or force_words_ids is not None + + is_contrastive_search_gen_mode = ( + top_k is not None and top_k > 1 and do_sample is False and penalty_alpha is not None and penalty_alpha > 0 + ) + is_greedy_gen_mode = ( - (num_beams == 1) and (num_beam_groups == 1) and do_sample is False and not is_constraint_gen_mode + (num_beams == 1) + and (num_beam_groups == 1) + and do_sample is False + and not is_constraint_gen_mode + and not is_contrastive_search_gen_mode ) is_sample_gen_mode = ( - (num_beams == 1) and (num_beam_groups == 1) and do_sample is True and not is_constraint_gen_mode + (num_beams == 1) + and (num_beam_groups == 1) + and do_sample is True + and not is_constraint_gen_mode + and not is_contrastive_search_gen_mode ) is_beam_gen_mode = ( - (num_beams > 1) and (num_beam_groups == 1) and do_sample is False and not is_constraint_gen_mode + (num_beams > 1) + and (num_beam_groups == 1) + and do_sample is False + and not is_constraint_gen_mode + and not is_contrastive_search_gen_mode ) is_beam_sample_gen_mode = ( - (num_beams > 1) and (num_beam_groups == 1) and do_sample is True and not is_constraint_gen_mode + (num_beams > 1) + and (num_beam_groups == 1) + and do_sample is True + and not is_constraint_gen_mode + and not is_contrastive_search_gen_mode + ) + is_group_beam_gen_mode = ( + (num_beams > 1) + and (num_beam_groups > 1) + and not is_constraint_gen_mode + and not is_contrastive_search_gen_mode ) - is_group_beam_gen_mode = (num_beams > 1) and (num_beam_groups > 1) and not is_constraint_gen_mode if num_beam_groups > num_beams: raise ValueError("`num_beam_groups` has to be smaller or equal to `num_beams`") @@ -1411,6 +1493,27 @@ def generate( **model_kwargs, ) + elif is_contrastive_search_gen_mode: + + if num_return_sequences > 1: + raise ValueError( + f"num_return_sequences has to be 1, but is {num_return_sequences} when doing contrastive search." + ) + + return self.contrastive_search( + input_ids, + top_k=top_k, + penalty_alpha=penalty_alpha, + logits_processor=logits_processor, + stopping_criteria=stopping_criteria, + pad_token_id=pad_token_id, + eos_token_id=eos_token_id, + output_scores=output_scores, + return_dict_in_generate=return_dict_in_generate, + synced_gpus=synced_gpus, + **model_kwargs, + ) + elif is_sample_gen_mode: # 10. prepare logits warper logits_warper = self._get_logits_warper( @@ -1646,6 +1749,324 @@ def typeerror(): **model_kwargs, ) + @torch.no_grad() + def contrastive_search( + self, + input_ids: torch.LongTensor, + top_k: Optional[int] = 1, + penalty_alpha: Optional[float] = 0, + logits_processor: Optional[LogitsProcessorList] = None, + logits_warper: Optional[LogitsProcessorList] = None, + stopping_criteria: Optional[StoppingCriteriaList] = None, + pad_token_id: Optional[int] = None, + eos_token_id: Optional[int] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_scores: Optional[bool] = None, + return_dict_in_generate: Optional[bool] = None, + synced_gpus: Optional[bool] = False, + **model_kwargs, + ) -> Union[GreedySearchOutput, torch.LongTensor]: + r""" + Generates sequences of token ids for models with a language modeling head using **contrastive search** and can + be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models. + + Parameters: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + The sequence used as a prompt for the generation. + top_k (`int`, *optional*, defaults to 1): + The size of the candidate set that is used to re-rank for contrastive search + penalty_alpha (`float`, *optional*, defaults to 0): + The degeneration penalty for contrastive search; activate when it is larger than 0 + logits_processor (`LogitsProcessorList`, *optional*): + An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`] + used to modify the prediction scores of the language modeling head applied at each generation step. + stopping_criteria (`StoppingCriteriaList`, *optional*): + An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`] + used to tell if the generation loop should stop. + pad_token_id (`int`, *optional*): + The id of the *padding* token. + eos_token_id (`int`, *optional*): + The id of the *end-of-sequence* token. + output_attentions (`bool`, *optional*, defaults to `False`): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more details. + output_hidden_states (`bool`, *optional*, defaults to `False`): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more details. + output_scores (`bool`, *optional*, defaults to `False`): + Whether or not to return the prediction scores. See `scores` under returned tensors for more details. + return_dict_in_generate (`bool`, *optional*, defaults to `False`): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + synced_gpus (`bool`, *optional*, defaults to `False`): + Whether to continue running the while loop until max_length (needed for ZeRO stage 3) + model_kwargs: + Additional model specific keyword arguments will be forwarded to the `forward` function of the model. + If model is an encoder-decoder model the kwargs should include `encoder_outputs`. + + Return: + [`~generation_utils.ContrastiveSearchDecoderOnlyOutput`], + [`~generation_utils.ContrastiveSearchEncoderDecoderOutput`] or `torch.LongTensor`: A `torch.LongTensor` + containing the generated tokens (default behaviour) or a + [`~generation_utils.ContrastiveSearchDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and + `return_dict_in_generate=True` or a [`~generation_utils.ContrastiveSearchEncoderDecoderOutput`] if + `model.config.is_encoder_decoder=True`. + + Examples: + ```python + >>> from transformers import ( + ... AutoTokenizer, + ... AutoModelForCausalLM, + ... MinLengthLogitsProcessor, + ... StoppingCriteriaList, + ... MaxLengthCriteria, + ... ) + + >>> tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m") + >>> model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m") + >>> # set pad_token_id to eos_token_id because GPT2 does not have a PAD token + >>> model.config.pad_token_id = model.config.eos_token_id + >>> input_prompt = "DeepMind Company is" + >>> input_ids = tokenizer(input_prompt, return_tensors="pt") + >>> stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=64)]) + >>> outputs = model.contrastive_search( + ... **input_ids, penalty_alpha=0.6, top_k=4, stopping_criteria=stopping_criteria + ... ) + >>> tokenizer.batch_decode(outputs, skip_special_tokens=True) + ["DeepMind Company is a company that focuses on the development and commercialization of artificial intelligence (AI). DeepMind’s mission is to help people understand and solve problems that are difficult to solve in the world today.\n\nIn this post, we talk about the benefits of deep learning in business and how it"] + ```""" + # init values + logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() + logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList() + stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() + pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id + eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id + output_scores = output_scores if output_scores is not None else self.config.output_scores + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict_in_generate = ( + return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate + ) + + # init attention / hidden states / scores tuples + scores = () if (return_dict_in_generate and output_scores) else None + decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None + + # keep track of which sequences are already finished + unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1) + + this_peer_finished = False # used by synced_gpus only + + step_counter = 0 + while True: + if synced_gpus: + # Under synced_gpus the `forward` call must continue until all gpus complete their sequence. + # The following logic allows an early break if all peers finished generating their sequence + this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device) + # send 0.0 if we finished, 1.0 otherwise + dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM) + # did all peers finish? the reduced sum will be 0.0 then + if this_peer_finished_flag.item() == 0.0: + break + + # prepare inputs + model_kwargs["use_cache"] = True + model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) + + # if the first step in the loop, encode all the prefix and obtain three parameters: (1) past_key_values; (2) last_hidden_states; (3) logit_for_next_step + if step_counter == 0: + # encode the given prefix and prepare model inputs; encoder-decoder model process the prefix and save the `encoder_outputs` + output = self(**model_inputs, output_hidden_states=True, output_attentions=True) + + # past_key_values is activated for fast decoding + if "past_key_values" not in output: + raise ValueError( + "self.__class__ cannot return `past_key_values` and can therefore **not** be used for" + " contrastive search." + ) + past_key_values = output.past_key_values + + # last decoder hidden states will be used to compute the degeneration penalty (cosine similarity with previous tokens) + if self.config.is_encoder_decoder: + last_hidden_states = output.decoder_hidden_states[-1] + else: + last_hidden_states = output.hidden_states[-1] + # next logit for contrastive search to select top-k candidate tokens + logit_for_next_step = output.logits[:, -1, :] + + # contrastive_search main logic start: + # contrastive search decoding consists of two steps: (1) candidate tokens recall; (2) candidate re-rank by degeneration penalty + bsz, seqlen, embed_dim = last_hidden_states.size() + + # logits processor + logit_for_next_step = logits_processor(input_ids, logit_for_next_step) + logit_for_next_step = logits_warper(input_ids, logit_for_next_step) + next_probs = nn.functional.softmax(logit_for_next_step, dim=-1) + + _, top_k_ids = torch.topk(logit_for_next_step, dim=-1, k=top_k) + top_k_probs = torch.gather(next_probs, dim=1, index=top_k_ids) + + # enlarge the past_key_values + new_key_values = [] + for layer in past_key_values: + items = [] + # item is either the key or the value matrix + for item in layer: + bsz, num_head, seq_len, esz = item.size() + item = ( + item.unsqueeze(1) + .expand(-1, top_k, -1, -1, -1) + .reshape(bsz * top_k, num_head, seq_len, esz) + .contiguous() + ) # [bsz*beam, num_head, seq_len, esz] + items.append(item) + new_key_values.append(items) + past_key_values = new_key_values + + # build next attention mask + if "attention_mask" in model_inputs: + attention_mask = model_inputs["attention_mask"] # [B, S] + # decoder-only model need the full attention mask, not only the mask for the last token + if self.config.is_encoder_decoder is False: + attention_mask = torch.cat([attention_mask, attention_mask.new_ones((bsz, 1))], dim=-1) + attention_mask = attention_mask.unsqueeze(1).expand(-1, top_k, -1).reshape(-1, attention_mask.size(-1)) + else: + attention_mask = None + + # encoder-decoder model also contains the `encoder_outputs` + if self.config.is_encoder_decoder and "encoder_outputs" in model_inputs: + encoder_outputs = model_inputs["encoder_outputs"] + else: + encoder_outputs = None + next_model_inputs = self.prepare_inputs_for_generation( + top_k_ids.view(-1, 1), + past=past_key_values, + attention_mask=attention_mask, + use_cache=True, + encoder_outputs=encoder_outputs, + ) + # compute the candidate tokens by the language model and collects their hidden_states + output = self(output_hidden_states=True, **next_model_inputs) + + if "past_key_values" not in output: + raise ValueError( + "self.__class__ cannot return `past_key_values` and can therefore **not** be used for contrastive" + " search." + ) + past_key_values = output.past_key_values + + logits = output.logits[:, -1, :] + # name is different for encoder-decoder and decoder-only models + if self.config.is_encoder_decoder: + next_hidden = output.decoder_hidden_states[-1] + full_hidden_states = output.decoder_hidden_states + else: + next_hidden = output.hidden_states[-1] + full_hidden_states = output.hidden_states + context_hidden = ( + last_hidden_states.unsqueeze(1).expand(-1, top_k, -1, -1).reshape(bsz * top_k, seqlen, embed_dim) + ) + + # compute the degeneratin penalty and re-rank the candidates based on the degeneration penalty and the model confidence + # the scores and index of the selected tokens are returned + selected_scores, selected_idx = ranking_fast( + context_hidden, next_hidden, top_k_probs, penalty_alpha, top_k + ) + + # prepare for the next step: (1) next token_id; (2) past_key_values; (3) last_hidden_states for computing the degeneration penalty; (4) logits for selecting next top-k candidates; (5) selected tokens scores (model confidence minus degeneration penalty); (6) decoder hidden_states + next_tokens = top_k_ids[range(len(top_k_ids)), selected_idx] + next_hidden = torch.stack(torch.split(next_hidden.squeeze(dim=1), top_k)) + next_hidden = next_hidden[range(bsz), selected_idx, :] + last_hidden_states = torch.cat([last_hidden_states, next_hidden.unsqueeze(1)], dim=1) + + decoder_hidden_states_one_step = [] + for layer in full_hidden_states: + layer = torch.stack(torch.split(layer.squeeze(dim=1), top_k)) + layer = layer[range(bsz), selected_idx, :] + decoder_hidden_states_one_step.append(layer) + + # select the past_key_value + new_key_values = [] + for layer in past_key_values: + items = [] + # item is either the key or the value matrix + for item in layer: + bsz_and_beam, num_head, seq_len, esz = item.size() + bsz = int(bsz_and_beam // top_k) + item = torch.stack(torch.split(item, top_k, dim=0)) # [B, K, num_head, seq_len, esz] + item = item[range(bsz), selected_idx, :, :, :] # [B, num_head, seq_len, esz] + items.append(item) + new_key_values.append(items) + past_key_values = new_key_values + + logit_for_next_step = torch.stack(torch.split(logits, top_k))[range(bsz), selected_idx, :] + # contrastive_search main logic end:: + # after running the above codes, we update following parameters: next_tokens, past_key_values, logit_for_next_step, selected_score, decoder_hidden_states_one_step + + if synced_gpus and this_peer_finished: + continue # don't waste resources running the code we don't need + + # Store scores, attentions and hidden_states when required + if return_dict_in_generate: + if output_scores: + scores += (selected_scores,) + + if output_hidden_states: + decoder_hidden_states += (decoder_hidden_states_one_step,) + + # finished sentences should have their next token be a padding token + if eos_token_id is not None: + if pad_token_id is None: + raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.") + next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences) + + # update generated ids, model inputs, and length for next step + input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) + if self.config.is_encoder_decoder: + outputs = Seq2SeqLMOutput( + past_key_values=past_key_values, + ) + else: + outputs = CausalLMOutputWithCrossAttentions( + past_key_values=past_key_values, attentions=model_kwargs["attention_mask"] + ) + model_kwargs = self._update_model_kwargs_for_generation( + outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder + ) + + # if eos_token was found in one sentence, set sentence to finished + if eos_token_id is not None: + unfinished_sequences = unfinished_sequences.mul((next_tokens != eos_token_id).long()) + + # stop when each sentence is finished, or if we exceed the maximum length + if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores): + if not synced_gpus: + break + else: + this_peer_finished = True + + # prepare model inputs + model_kwargs["past_key_values"] = past_key_values + step_counter += 1 + + if return_dict_in_generate: + if self.config.is_encoder_decoder: + return ContrastiveSearchEncoderDecoderOutput( + sequences=input_ids, + scores=scores, + decoder_hidden_states=decoder_hidden_states, + ) + else: + return ContrastiveSearchDecoderOnlyOutput( + sequences=input_ids, + scores=scores, + hidden_states=decoder_hidden_states, + ) + else: + return input_ids + def greedy_search( self, input_ids: torch.LongTensor, @@ -3457,3 +3878,25 @@ def top_k_top_p_filtering( ) return logits + + +def ranking_fast( + context_hidden: torch.FloatTensor, + next_hidden: torch.FloatTensor, + next_top_k_probs: torch.FloatTensor, + alpha: float, + beam_width: int, +) -> Tuple[torch.FloatTensor]: + """ + context_hidden: bsz*beam x seqlen x embed_dim next_hidden: bsz*beam x 1 x embed_dim next_top_k_probs: bsz x beam + """ + _, context_len, embed_dim = context_hidden.size() + norm_context_hidden = context_hidden / context_hidden.norm(dim=2, keepdim=True) + norm_next_hidden = next_hidden / next_hidden.norm(dim=2, keepdim=True) + cosine_matrix = torch.matmul(norm_context_hidden, norm_next_hidden.transpose(1, 2)).squeeze(-1) # [B*K, S] + degeneration_penalty, _ = torch.max(cosine_matrix, dim=-1) # [B*K] + next_top_k_probs = next_top_k_probs.view(-1) # [B*K] + contrastive_score = (1.0 - alpha) * next_top_k_probs - alpha * degeneration_penalty + contrastive_score = torch.stack(torch.split(contrastive_score, beam_width)) # [B, K] + selected_scores, selected_idx = contrastive_score.max(dim=-1) # [B] + return torch.log(selected_scores), selected_idx diff --git a/tests/generation/test_generation_utils.py b/tests/generation/test_generation_utils.py index f48cfff83cb85..d2347bc0aaf45 100644 --- a/tests/generation/test_generation_utils.py +++ b/tests/generation/test_generation_utils.py @@ -27,6 +27,7 @@ import torch from transformers import ( + AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer, BartForConditionalGeneration, @@ -34,8 +35,10 @@ GPT2LMHeadModel, GPT2Tokenizer, ImageGPTForCausalImageModeling, + OPTForCausalLM, Speech2TextForConditionalGeneration, SpeechEncoderDecoderModel, + T5ForConditionalGeneration, VisionEncoderDecoderModel, pipeline, top_k_top_p_filtering, @@ -1693,6 +1696,140 @@ def test_diverse_beam_search(self): ], ) + @slow + def test_contrastive_search_bart(self): + article = """ New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County, New York. +A year later, she got married again in Westchester County, but to a different man and without divorcing her first husband. +Only 18 days after that marriage, she got hitched yet again. Then, Barrientos declared "I do" five more times, sometimes only within two weeks of each other. +In 2010, she married once more, this time in the Bronx. In an application for a marriage license, she stated it was her "first and only" marriage. +Barrientos, now 39, is facing two criminal counts of "offering a false instrument for filing in the first degree," referring to her false statements on the +2010 marriage license application, according to court documents. +Prosecutors said the marriages were part of an immigration scam. +On Friday, she pleaded not guilty at State Supreme Court in the Bronx, according to her attorney, Christopher Wright, who declined to comment further. +After leaving court, Barrientos was arrested and charged with theft of service and criminal trespass for allegedly sneaking into the New York subway through an emergency exit, said Detective +Annette Markowski, a police spokeswoman. In total, Barrientos has been married 10 times, with nine of her marriages occurring between 1999 and 2002. +All occurred either in Westchester County, Long Island, New Jersey or the Bronx. She is believed to still be married to four men, and at one time, she was married to eight men at once, prosecutors say. +Prosecutors said the immigration scam involved some of her husbands, who filed for permanent residence status shortly after the marriages. +Any divorces happened only after such filings were approved. It was unclear whether any of the men will be prosecuted. +The case was referred to the Bronx District Attorney\'s Office by Immigration and Customs Enforcement and the Department of Homeland Security\'s +Investigation Division. Seven of the men are from so-called "red-flagged" countries, including Egypt, Turkey, Georgia, Pakistan and Mali. +Her eighth husband, Rashid Rajput, was deported in 2006 to his native Pakistan after an investigation by the Joint Terrorism Task Force. +If convicted, Barrientos faces up to four years in prison. Her next court appearance is scheduled for May 18. +""" + bart_tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-cnn") + bart_model = BartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn").to(torch_device) + input_ids = bart_tokenizer( + article, add_special_tokens=False, truncation=True, max_length=512, return_tensors="pt" + ).input_ids.to(torch_device) + + outputs = bart_model.generate(input_ids, penalty_alpha=0.5, top_k=5, max_length=64) + generated_text = bart_tokenizer.batch_decode(outputs, skip_special_tokens=True) + + self.assertListEqual( + generated_text, + [ + """Liana Barrientos, 39, pleaded not guilty to two counts of "offering a false instrument" Prosecutors say the marriages were part of an immigration scam. In total, Barriento has been married 10 times, with nine of her marriages occurring between 1999 and 2002.""" + ], + ) + + @slow + def test_contrastive_search_t5(self): + article = """ New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County, New York. +A year later, she got married again in Westchester County, but to a different man and without divorcing her first husband. +Only 18 days after that marriage, she got hitched yet again. Then, Barrientos declared "I do" five more times, sometimes only within two weeks of each other. +In 2010, she married once more, this time in the Bronx. In an application for a marriage license, she stated it was her "first and only" marriage. +Barrientos, now 39, is facing two criminal counts of "offering a false instrument for filing in the first degree," referring to her false statements on the +2010 marriage license application, according to court documents. +Prosecutors said the marriages were part of an immigration scam. +On Friday, she pleaded not guilty at State Supreme Court in the Bronx, according to her attorney, Christopher Wright, who declined to comment further. +After leaving court, Barrientos was arrested and charged with theft of service and criminal trespass for allegedly sneaking into the New York subway through an emergency exit, said Detective +Annette Markowski, a police spokeswoman. In total, Barrientos has been married 10 times, with nine of her marriages occurring between 1999 and 2002. +All occurred either in Westchester County, Long Island, New Jersey or the Bronx. She is believed to still be married to four men, and at one time, she was married to eight men at once, prosecutors say. +Prosecutors said the immigration scam involved some of her husbands, who filed for permanent residence status shortly after the marriages. +Any divorces happened only after such filings were approved. It was unclear whether any of the men will be prosecuted. +The case was referred to the Bronx District Attorney\'s Office by Immigration and Customs Enforcement and the Department of Homeland Security\'s +Investigation Division. Seven of the men are from so-called "red-flagged" countries, including Egypt, Turkey, Georgia, Pakistan and Mali. +Her eighth husband, Rashid Rajput, was deported in 2006 to his native Pakistan after an investigation by the Joint Terrorism Task Force. +If convicted, Barrientos faces up to four years in prison. Her next court appearance is scheduled for May 18. +""" + article = "summarize: " + article.strip() + t5_tokenizer = AutoTokenizer.from_pretrained("flax-community/t5-base-cnn-dm") + t5_model = T5ForConditionalGeneration.from_pretrained("flax-community/t5-base-cnn-dm").to(torch_device) + input_ids = t5_tokenizer( + article, add_special_tokens=False, truncation=True, max_length=512, return_tensors="pt" + ).input_ids.to(torch_device) + + outputs = t5_model.generate(input_ids, penalty_alpha=0.5, top_k=5, max_length=64) + generated_text = t5_tokenizer.batch_decode(outputs, skip_special_tokens=True) + + self.assertListEqual( + generated_text, + [ + """Liana Barrientos has been married 10 times, nine of them in the Bronx. Her husbands filed for permanent residence after the marriages, prosecutors say.""" + ], + ) + + @slow + def test_contrastive_search_opt(self): + article = r"""A chat between a curious human and the Statue of Liberty. + +Human: What is your name? +Statue: I am the Statue of Liberty. +Human: Where do you live? +Statue: New York City. +Human: How long have you lived there?""" + + opt_tokenizer = GPT2Tokenizer.from_pretrained("facebook/opt-6.7b") + opt_model = OPTForCausalLM.from_pretrained("facebook/opt-6.7b").to(torch_device) + input_ids = opt_tokenizer(article, return_tensors="pt").input_ids.to(torch_device) + + outputs = opt_model.generate(input_ids, penalty_alpha=0.6, top_k=5, max_length=256) + generated_text = opt_tokenizer.batch_decode(outputs, skip_special_tokens=True) + + self.assertListEqual( + generated_text, + [ + """A chat between a curious human and the Statue of Liberty.\n\nHuman: What is your name?\nStatue: I am the Statue of Liberty.\nHuman: Where do you live?\nStatue: New York City.\nHuman: How long have you lived there?\nStatue: Since 1884.\nHuman: Why did you come to America?\nStatue: I was given to the United States by France as a gift for helping the French during the Franco-Prussian War.\nHuman: What do you think of America?\nStatue: I love it. It is the greatest country in the world.\nHuman: What’s the weather like in New York?\nStatue: It is cold.\nHuman: Is it safe to walk around at night?\nStatue: Yes. There are policemen everywhere.\nHuman: Do you have any children?\nStatue: Not yet. My pedestal is empty.\nHuman: What would you like to say to people who want to immigrate to America?\nStatue: Come on over. You will be happy here. We have everything you need.\nSource: http://www.statueofliberty.org/index.cf""" + ], + ) + + @slow + def test_contrastive_search_gptj(self): + article = """DeepMind Technologies is a British artificial intelligence subsidiary of Alphabet Inc. and research laboratory founded in 2010. DeepMind was acquired by Google in 2014. The company is based""" + + opt_tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B") + opt_model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-j-6B").to(torch_device) + input_ids = opt_tokenizer(article, return_tensors="pt").input_ids.to(torch_device) + + outputs = opt_model.generate(input_ids, penalty_alpha=0.6, top_k=4, max_length=256) + generated_text = opt_tokenizer.batch_decode(outputs, skip_special_tokens=True) + + self.assertListEqual( + generated_text, + [ + """DeepMind Technologies is a British artificial intelligence subsidiary of Alphabet Inc. and research laboratory founded in 2010. DeepMind was acquired by Google in 2014. The company is based in London, United Kingdom with offices in Mountain View, San Francisco, New York City, Paris, Tokyo, Seoul, Beijing, Singapore, Tel Aviv, Dublin, Sydney, and Melbourne.[1]\n\nContents\n\nIn 2010, Google\'s parent company, Alphabet, announced a $500 million investment in DeepMind, with the aim of creating a company that would apply deep learning to problems in healthcare, energy, transportation, and other areas.[2]\n\nOn April 23, 2014, Google announced that it had acquired DeepMind for $400 million in cash and stock.[3] The acquisition was seen as a move to strengthen Google\'s position in the fast-growing field of artificial intelligence (AI), which it had invested in since 2010.[4] Google CEO Larry Page said that the company was "excited to have DeepMind on board" and that "this is a step towards our goal of building AI that works for everyone, not just a few".[5]\n\nDeepMind\'s co-founders, Demis Hassabis and Mustafa Suleyman, were named CEO and C""" + ], + ) + + @slow + def test_contrastive_search_gpt2(self): + article = """DeepMind Technologies is a British artificial intelligence subsidiary of Alphabet Inc. and research laboratory founded in 2010. DeepMind was acquired by Google in 2014. The company is based""" + + gpt2_tokenizer = AutoTokenizer.from_pretrained("gpt2-large") + gpt2_model = GPT2LMHeadModel.from_pretrained("gpt2-large").to(torch_device) + input_ids = gpt2_tokenizer(article, return_tensors="pt").input_ids.to(torch_device) + + outputs = gpt2_model.generate(input_ids, penalty_alpha=0.6, top_k=4, max_length=256) + + generated_text = gpt2_tokenizer.batch_decode(outputs, skip_special_tokens=True) + + self.assertListEqual( + generated_text, + [ + """DeepMind Technologies is a British artificial intelligence subsidiary of Alphabet Inc. and research laboratory founded in 2010. DeepMind was acquired by Google in 2014. The company is based in London, United Kingdom\n\nGoogle has a lot of data on its users and uses it to improve its products, such as Google Now, which helps users find the information they\'re looking for on the web. But the company is not the only one to collect data on its users. Facebook, for example, has its own facial recognition technology, as well as a database of millions of photos that it uses to personalize its News Feed.\n\nFacebook\'s use of data is a hot topic in the tech industry, with privacy advocates concerned about the company\'s ability to keep users\' information private. In a blog post last year, Facebook CEO Mark Zuckerberg said his company would "do our best to be transparent about our data use and how we use it."\n\n"We have made it clear that we do not sell or share your data with third parties," Zuckerberg wrote. "If you have questions or concerns, please reach out to us at privacy@facebook.com."\n\nGoogle declined to comment on the privacy implications of its use of data, but said in a statement to The Associated Press that""" + ], + ) + def test_max_length_backward_compat_greedy(self): article = """Justin Timberlake and Jessica Biel, welcome to parenthood.""" bart_tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart") @@ -2050,6 +2187,134 @@ def test_max_new_tokens_encoder_decoder(self): with self.assertRaises(ValueError): bart_model.generate(decoder_input_ids=input_ids, max_new_tokens=10, max_length=20) + def test_max_new_tokens_decoder_only_contrastive_search_t5(self): + article = """Justin Timberlake and Jessica Biel, welcome to parenthood.""" + t5_tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5") + t5_model = T5ForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-t5").to(torch_device) + input_ids = t5_tokenizer(article, return_tensors="pt").input_ids.to(torch_device) + + self.assertEqual(list(input_ids.shape), [1, 56]) + + max_new_tokens = 3 + t5_model.config.max_length = 20 + t5_model.config.eos_token_id = None + + # Encoder decoder call + outputs = t5_model.generate(input_ids, max_new_tokens=max_new_tokens, penalty_alpha=0.6, top_k=4) + # 1 BOS + 3 new tokens + self.assertEqual(list(outputs.shape), [1, 4]) + + # Decoder only call + outputs = t5_model.generate( + decoder_input_ids=input_ids, max_new_tokens=max_new_tokens, penalty_alpha=0.6, top_k=4 + ) + # 56 + 3 new tokens + self.assertEqual(list(outputs.shape), [1, 59]) + + # Encoder decoder call > 20 + outputs = t5_model.generate(max_new_tokens=max_new_tokens + 20, penalty_alpha=0.6, top_k=4) + + # 1 BOS + 20 + 3 new tokens + self.assertEqual(list(outputs.shape), [1, 24]) + + # max_new_tokens and max_length serve the same purpose and must not be used together. + with self.assertRaises(ValueError): + t5_model.generate( + decoder_input_ids=input_ids, max_new_tokens=10, max_length=20, penalty_alpha=0.6, top_k=4 + ) + + def test_max_new_tokens_decoder_only_contrastive_search_bart(self): + article = """Justin Timberlake and Jessica Biel, welcome to parenthood.""" + bart_tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart") + bart_model = BartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart").to( + torch_device + ) + input_ids = bart_tokenizer(article, return_tensors="pt").input_ids.to(torch_device) + + self.assertEqual(list(input_ids.shape), [1, 29]) + + max_new_tokens = 3 + bart_model.config.max_length = 20 + bart_model.config.eos_token_id = None + + # Encoder decoder call + outputs = bart_model.generate(input_ids, max_new_tokens=max_new_tokens, penalty_alpha=0.6, top_k=4) + # 1 BOS + 3 new tokens + self.assertEqual(list(outputs.shape), [1, 4]) + + # Decoder only call + outputs = bart_model.generate( + decoder_input_ids=input_ids, max_new_tokens=max_new_tokens, penalty_alpha=0.6, top_k=4 + ) + # 29 + 3 new tokens + self.assertEqual(list(outputs.shape), [1, 32]) + + # Encoder decoder call > 20 + outputs = bart_model.generate(max_new_tokens=max_new_tokens + 20, penalty_alpha=0.6, top_k=4) + + # 1 BOS + 20 + 3 new tokens + self.assertEqual(list(outputs.shape), [1, 24]) + + # max_new_tokens and max_length serve the same purpose and must not be used together. + with self.assertRaises(ValueError): + bart_model.generate( + decoder_input_ids=input_ids, max_new_tokens=10, max_length=20, penalty_alpha=0.6, top_k=4 + ) + + def test_max_new_tokens_decoder_only_contrastive_search_gptj(self): + article = """Justin Timberlake.""" + gptj_tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gptj") + gptj_model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gptj").to(torch_device) + input_ids = gptj_tokenizer(article, return_tensors="pt").input_ids.to(torch_device) + + self.assertEqual(list(input_ids.shape), [1, 9]) + + max_new_tokens = 3 + gptj_model.config.max_length = 20 + + # call < 20 + outputs = gptj_model.generate(input_ids, max_new_tokens=max_new_tokens, penalty_alpha=0.6, top_k=4) + + # 9 input_ids + 3 new tokens + self.assertEqual(list(outputs.shape), [1, 12]) + + # call > 20 + outputs = gptj_model.generate(max_new_tokens=max_new_tokens + 20, penalty_alpha=0.6, top_k=4) + + # 1 BOS token + 23 new tokens + self.assertEqual(list(outputs.shape), [1, 24]) + + # max_new_tokens and max_length serve the same purpose and must not be used together. + with self.assertRaises(ValueError): + gptj_model.generate(input_ids=input_ids, max_new_tokens=10, max_length=20, penalty_alpha=0.6, top_k=4) + + def test_max_new_tokens_decoder_only_contrastive_search_gpt2(self): + article = """Justin Timberlake.""" + gpt2_tokenizer = GPT2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2") + gpt2_model = GPT2LMHeadModel.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device) + input_ids = gpt2_tokenizer(article, return_tensors="pt").input_ids.to(torch_device) + + self.assertEqual(list(input_ids.shape), [1, 9]) + + max_new_tokens = 3 + gpt2_model.config.max_length = 20 + + # call < 20 + outputs = gpt2_model.generate(input_ids, max_new_tokens=max_new_tokens, penalty_alpha=0.6, top_k=4) + + # 9 input_ids + 3 new tokens + self.assertEqual(list(outputs.shape), [1, 12]) + + # call > 20 + outputs = gpt2_model.generate(max_new_tokens=max_new_tokens + 20, penalty_alpha=0.6, top_k=4) + + # 1 BOS token + 23 new tokens + self.assertEqual(list(outputs.shape), [1, 24]) + + # max_new_tokens and max_length serve the same purpose and must not be used together. + with self.assertRaises(ValueError): + gpt2_model.generate(input_ids=input_ids, max_new_tokens=10, max_length=20, penalty_alpha=0.6, top_k=4) + def test_max_new_tokens_decoder_only(self): article = """Justin Timberlake.""" gpt2_tokenizer = GPT2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")