diff --git a/src/transformers/generation_flax_utils.py b/src/transformers/generation_flax_utils.py index 59d66a0fe2b4c..2f80c7fcf27e9 100644 --- a/src/transformers/generation_flax_utils.py +++ b/src/transformers/generation_flax_utils.py @@ -15,6 +15,7 @@ # limitations under the License. +import warnings from functools import partial from typing import Dict, Optional @@ -163,6 +164,7 @@ def generate( self, input_ids: jnp.ndarray, max_length: Optional[int] = None, + max_new_tokens: Optional[int] = None, pad_token_id: Optional[int] = None, bos_token_id: Optional[int] = None, eos_token_id: Optional[int] = None, @@ -209,8 +211,12 @@ def generate( input_ids (`jnp.ndarray` of shape `(batch_size, sequence_length)`): The sequence used as a prompt for the generation. - max_length (`int`, *optional*, defaults to 20): - The maximum length of the sequence to be generated. + max_length (`int`, *optional*, defaults to `model.config.max_length`): + The maximum length the generated tokens can have. Corresponds to the length of the input prompt + + `max_new_tokens`. In general, prefer the use of `max_new_tokens`, which ignores the number of tokens in + the prompt. + max_new_tokens (`int`, *optional*): + The maximum numbers of tokens to generate, ignoring the number of tokens in the prompt. do_sample (`bool`, *optional*, defaults to `False`): Whether or not to use sampling ; use greedy decoding otherwise. temperature (`float`, *optional*, defaults to 1.0): @@ -258,8 +264,6 @@ def generate( >>> tokenizer.batch_decode(outputs, skip_special_tokens=True) ```""" # set init values - max_length = max_length if max_length is not None else self.config.max_length - min_length = min_length if min_length is not None else self.config.min_length bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id @@ -270,11 +274,6 @@ def generate( if decoder_start_token_id is None and self.config.is_encoder_decoder: raise ValueError("`decoder_start_token_id` has to be defined for encoder-decoder generation.") - if min_length is not None and min_length > max_length: - raise ValueError( - f"Unfeasable length constraints: the minimum length ({min_length}) is larger than the maximum " - f"length ({max_length})" - ) if self.config.is_encoder_decoder: # add encoder_outputs to model_kwargs @@ -283,6 +282,42 @@ def generate( # prepare decoder_input_ids for generation input_ids = jnp.ones((input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id + # Prepare `max_length` depending on other stopping criteria. + input_ids_seq_length = input_ids.shape[-1] + if max_length is None and max_new_tokens is None: + warnings.warn( + "Neither `max_length` nor `max_new_tokens` have been set, `max_length` will default to " + f"{self.config.max_length} (`self.config.max_length`). Controlling `max_length` via the config is " + "deprecated and `max_length` will be removed from the config in v5 of Transformers -- we recommend " + "using `max_new_tokens` to control the maximum length of the generation.", + UserWarning, + ) + elif max_length is None and max_new_tokens is not None: + max_length = max_new_tokens + input_ids_seq_length + elif max_length is not None and max_new_tokens is not None: + raise ValueError( + "Both `max_new_tokens` and `max_length` have been set but they serve the same purpose -- setting a" + " limit to the generated output length. Remove one of those arguments. Please refer to the" + " documentation for more information. " + "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)" + ) + # default to config if still None + max_length = max_length if max_length is not None else self.config.max_length + min_length = min_length if min_length is not None else self.config.min_length + + if min_length is not None and min_length > max_length: + raise ValueError( + f"Unfeasable length constraints: the minimum length ({min_length}) is larger than the maximum " + f"length ({max_length})" + ) + if input_ids_seq_length >= max_length: + input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids" + logger.warning( + f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to" + f" {max_length}. This can lead to unexpected behavior. You should consider increasing" + "`max_new_tokens`." + ) + do_sample = do_sample if do_sample is not None else self.config.do_sample num_beams = num_beams if num_beams is not None else self.config.num_beams diff --git a/src/transformers/generation_tf_utils.py b/src/transformers/generation_tf_utils.py index 1a9b19e2b595b..ec9a61e90099f 100644 --- a/src/transformers/generation_tf_utils.py +++ b/src/transformers/generation_tf_utils.py @@ -15,6 +15,7 @@ # limitations under the License. import inspect +import warnings from dataclasses import dataclass from typing import Any, Dict, List, Optional, Tuple, Union @@ -53,8 +54,8 @@ class TFGreedySearchDecoderOnlyOutput(ModelOutput): if all batches finished early due to the `eos_token_id`. scores (`tuple(tf.Tensor)` *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`): Processed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax) - at each generation step. `(max_length-input_ids.shape[-1],)`-shaped tuple of `tf.Tensor` with each tensor - of shape `(batch_size, config.vocab_size)`). + at each generation step. Tuple of `tf.Tensor` with up to `max_new_tokens` elements (one element for each + generated token), with each tensor of shape `(batch_size, config.vocab_size)`. attentions (`tuple(tuple(tf.Tensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`): Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of `tf.Tensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`. @@ -83,8 +84,8 @@ class TFGreedySearchEncoderDecoderOutput(ModelOutput): if all batches finished early due to the `eos_token_id`. scores (`tuple(tf.Tensor)` *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`): Processed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax) - at each generation step. `(max_length-1,)`-shaped tuple of `tf.Tensor` with each tensor of shape - `(batch_size, config.vocab_size)`). + at each generation step. Tuple of `tf.Tensor` with up to `max_new_tokens` elements (one element for each + generated token), with each tensor of shape `(batch_size, config.vocab_size)`. encoder_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`): Tuple of `tf.Tensor` (one for each layer of the decoder) of shape `(batch_size, num_heads, sequence_length, sequence_length)`. @@ -123,8 +124,8 @@ class TFSampleDecoderOnlyOutput(ModelOutput): if all batches finished early due to the `eos_token_id`. scores (`tuple(tf.Tensor)` *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`): Processed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax) - at each generation step. `(max_length-input_ids.shape[-1],)`-shaped tuple of `tf.Tensor` with each tensor - of shape `(batch_size*num_return_sequences, config.vocab_size)`). + at each generation step. Tuple of `tf.Tensor` with up to `max_new_tokens` elements (one element for each + generated token), with each tensor of shape `(batch_size*num_return_sequences, config.vocab_size)`. attentions (`tuple(tuple(tf.Tensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`): Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of `tf.Tensor` of shape `(num_return_sequences*batch_size, num_heads, generated_length, sequence_length)`. @@ -153,8 +154,8 @@ class TFSampleEncoderDecoderOutput(ModelOutput): if all batches finished early due to the `eos_token_id`. scores (`tuple(tf.Tensor)` *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`): Processed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax) - at each generation step. `(max_length-1,)`-shaped tuple of `tf.Tensor` with each tensor of shape - `(batch_size*num_return_sequences, config.vocab_size)`). + at each generation step. Tuple of `tf.Tensor` with up to `max_new_tokens` elements (one element for each + generated token), with each tensor of shape `(batch_size*num_return_sequences, config.vocab_size)`. encoder_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`): Tuple of `tf.Tensor` (one for each layer of the decoder) of shape `(batch_size*num_return_sequences, num_heads, sequence_length, sequence_length)`. @@ -194,9 +195,9 @@ class TFBeamSearchDecoderOnlyOutput(ModelOutput): Final beam scores of the generated `sequences`. scores (`tuple(tf.Tensor)` *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`): Processed beam scores for each vocabulary token at each generation step. Beam scores consisting of log - softmax scores for each vocabulary token and sum of log softmax of previously generated tokens in this beam - . `(max_length-input_ids.shape[-1],)`-shaped tuple of `tf.Tensor` with each tensor of shape - `(batch_size*num_beams*num_return_sequences, config.vocab_size)`). + softmax scores for each vocabulary token and sum of log softmax of previously generated tokens in this + beam. Tuple of `tf.Tensor` with up to `max_new_tokens` elements (one element for each generated token), + with each tensor of shape `(batch_size*num_beams*num_return_sequences, config.vocab_size)`. attentions (`tuple(tuple(tf.Tensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`): Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of `tf.Tensor` of shape `(batch_size*num_beams, num_heads, generated_length, sequence_length)`. @@ -227,9 +228,9 @@ class TFBeamSearchEncoderDecoderOutput(ModelOutput): Final beam scores of the generated `sequences`. scores (`tuple(tf.Tensor)` *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`): Processed beam scores for each vocabulary token at each generation step. Beam scores consisting of log - softmax scores for each vocabulary token and sum of log softmax of previously generated tokens in this beam - . `(max_length-1,)`-shaped tuple of `tf.Tensor` with each tensor of shape `(batch_size*num_beams, - config.vocab_size)`). + softmax scores for each vocabulary token and sum of log softmax of previously generated tokens in this + beam. `Tuple of `tf.Tensor` with up to `max_new_tokens` elements (one element for each generated token), + with each tensor of shape `(batch_size*num_beams, config.vocab_size)`. attentions (`tuple(tuple(tf.Tensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`): encoder_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`): Tuple of `tf.Tensor` (one for each layer of the decoder) of shape `(batch_size, num_heads, sequence_length, @@ -272,9 +273,9 @@ class TFBeamSampleDecoderOnlyOutput(ModelOutput): Final beam scores of the generated `sequences`. scores (`tuple(tf.Tensor)` *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`): Processed beam scores for each vocabulary token at each generation step. Beam scores consisting of log - softmax scores for each vocabulary token and sum of log softmax of previously generated tokens in this beam - . `(max_length-input_ids.shape[-1],)`-shaped tuple of `tf.Tensor` with each tensor of shape - `(batch_size*num_beams*num_return_sequences, config.vocab_size)`). + softmax scores for each vocabulary token and sum of log softmax of previously generated tokens in this + beam. Tuple of `tf.Tensor` with up to `max_new_tokens` elements (one element for each generated token), + with each tensor of shape `(batch_size*num_beams*num_return_sequences, config.vocab_size)`. attentions (`tuple(tuple(tf.Tensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`): Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of `tf.Tensor` of shape `(batch_size*num_beams, num_heads, generated_length, sequence_length)`. @@ -305,9 +306,9 @@ class TFBeamSampleEncoderDecoderOutput(ModelOutput): Final beam scores of the generated `sequences`. scores (`tuple(tf.Tensor)` *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`): Processed beam scores for each vocabulary token at each generation step. Beam scores consisting of log - softmax scores for each vocabulary token and sum of log softmax of previously generated tokens in this beam - . `(max_length-1,)`-shaped tuple of `tf.Tensor` with each tensor of shape `(batch_size*num_beams, - config.vocab_size)`). + softmax scores for each vocabulary token and sum of log softmax of previously generated tokens in this + beam. Tuple of `tf.Tensor` with up to `max_new_tokens` elements (one element for each generated token), + with each tensor of shape `(batch_size*num_beams, config.vocab_size)`. encoder_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`): Tuple of `tf.Tensor` (one for each layer of the decoder) of shape `(batch_size, num_heads, sequence_length, sequence_length)`. @@ -375,6 +376,7 @@ def generate( self, input_ids=None, max_length=None, + max_new_tokens=None, min_length=None, do_sample=None, early_stopping=None, @@ -423,8 +425,12 @@ def generate( method initializes it with `bos_token_id` and a batch size of 1. For decoder-only models `inputs` should of in the format of `input_ids`. For encoder-decoder models *inputs* can represent any of `input_ids`, `input_values`, `input_features`, or `pixel_values`. - max_length (`int`, *optional*, defaults to 20): - The maximum length of the sequence to be generated. + max_length (`int`, *optional*, defaults to `model.config.max_length`): + The maximum length the generated tokens can have. Corresponds to the length of the input prompt + + `max_new_tokens`. In general, prefer the use of `max_new_tokens`, which ignores the number of tokens in + the prompt. + max_new_tokens (`int`, *optional*): + The maximum numbers of tokens to generate, ignoring the number of tokens in the prompt. min_length (`int`, *optional*, defaults to 10): The minimum length of the sequence to be generated. do_sample (`bool`, *optional*, defaults to `False`): @@ -577,6 +583,7 @@ def generate( return self._generate( input_ids=input_ids, max_length=max_length, + max_new_tokens=max_new_tokens, min_length=min_length, do_sample=do_sample, early_stopping=early_stopping, @@ -1286,6 +1293,7 @@ def _generate( self, input_ids=None, max_length=None, + max_new_tokens=None, min_length=None, do_sample=None, early_stopping=None, @@ -1332,8 +1340,12 @@ def _generate( input_ids (`tf.Tensor` of `dtype=tf.int32` and shape `(batch_size, sequence_length)`, *optional*): The sequence used as a prompt for the generation. If `None` the method initializes it with `bos_token_id` and a batch size of 1. - max_length (`int`, *optional*, defaults to 20): - The maximum length of the sequence to be generated. + max_length (`int`, *optional*, defaults to `model.config.max_length`): + The maximum length the generated tokens can have. Corresponds to the length of the input prompt + + `max_new_tokens`. In general, prefer the use of `max_new_tokens`, which ignores the number of tokens in + the prompt. + max_new_tokens (`int`, *optional*): + The maximum numbers of tokens to generate, ignoring the number of tokens in the prompt. min_length (`int`, *optional*, defaults to 10): The minimum length of the sequence to be generated. do_sample (`bool`, *optional*, defaults to `False`): @@ -1474,8 +1486,6 @@ def _generate( outputs = model.generate(input_ids=input_ids, max_length=100, do_sample=True, bad_words_ids=bad_words_ids) ```""" # 1. Set generation parameters if not already defined - max_length = max_length if max_length is not None else self.config.max_length - min_length = min_length if min_length is not None else self.config.min_length length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping @@ -1514,12 +1524,6 @@ def _generate( logger.warning(f"Setting `pad_token_id` to {eos_token_id} (first `eos_token_id`) to generate sequence") pad_token_id = eos_token_id - if min_length is not None and min_length > max_length: - raise ValueError( - f"Unfeasable length constraints: the minimum length ({min_length}) is larger than the maximum " - f"length ({max_length})" - ) - use_xla = not tf.executing_eagerly() if use_xla and not self.supports_xla_generation: raise ValueError( @@ -1561,21 +1565,49 @@ def _generate( model_kwargs=model_kwargs, ) - if input_ids.shape[-1] >= max_length: + # 5. Prepare `max_length` depending on other stopping criteria. + input_ids_seq_length = input_ids.shape[-1] + if max_length is None and max_new_tokens is None: + warnings.warn( + "Neither `max_length` nor `max_new_tokens` have been set, `max_length` will default to " + f"{self.config.max_length} (`self.config.max_length`). Controlling `max_length` via the config is " + "deprecated and `max_length` will be removed from the config in v5 of Transformers -- we recommend " + "using `max_new_tokens` to control the maximum length of the generation.", + UserWarning, + ) + elif max_length is None and max_new_tokens is not None: + max_length = max_new_tokens + input_ids_seq_length + elif max_length is not None and max_new_tokens is not None: raise ValueError( - f"The context has {input_ids.shape[-1]} number of tokens, " - f"but `max_length` is only {max_length}. " - "Please make sure that `max_length` is bigger than the number of tokens, " - "by setting either `generate(max_length=...,...)` or `config.max_length = ...`" + "Both `max_new_tokens` and `max_length` have been set but they serve the same purpose -- setting a" + " limit to the generated output length. Remove one of those arguments. Please refer to the" + " documentation for more information. " + "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)" + ) + # default to config if still None + max_length = max_length if max_length is not None else self.config.max_length + min_length = min_length if min_length is not None else self.config.min_length + + if min_length is not None and min_length > max_length: + raise ValueError( + f"Unfeasable length constraints: the minimum length ({min_length}) is larger than the maximum " + f"length ({max_length})" + ) + if input_ids_seq_length >= max_length: + input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids" + logger.warning( + f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to" + f" {max_length}. This can lead to unexpected behavior. You should consider increasing" + "`max_new_tokens`." ) - # 5. determine generation mode + # 6. determine generation mode # TODO(Matt, Joao, Patrick) - add more use cases here is_greedy_gen_mode = (num_beams == 1) and do_sample is False is_sample_gen_mode = (num_beams == 1) and do_sample is True is_beam_gen_mode = (num_beams > 1) and do_sample is False - # 6. prepare distribution pre_processing samplers + # 7. prepare distribution pre_processing samplers logits_processor = self._get_logits_processor( repetition_penalty=repetition_penalty, no_repeat_ngram_size=no_repeat_ngram_size, @@ -1587,13 +1619,13 @@ def _generate( forced_eos_token_id=forced_eos_token_id, ) - # 7. go into different generation modes + # 8. go into different generation modes if is_greedy_gen_mode: if num_return_sequences > 1: raise ValueError( f"num_return_sequences has to be 1, but is {num_return_sequences} when doing greedy search." ) - # 8. run greedy search + # 9. run greedy search return self.greedy_search( input_ids, max_length=max_length, @@ -1605,10 +1637,10 @@ def _generate( **model_kwargs, ) elif is_sample_gen_mode: - # 8. prepare logits warper + # 9. prepare logits warper logits_warper = self._get_logits_warper(top_k=top_k, top_p=top_p, temperature=temperature) - # 9. expand input_ids with `num_return_sequences` additional sequences per batch + # 10. expand input_ids with `num_return_sequences` additional sequences per batch input_ids, model_kwargs = self._expand_inputs_for_generation( input_ids, expand_size=num_return_sequences, @@ -1616,7 +1648,7 @@ def _generate( **model_kwargs, ) - # 10. run sample + # 11. run sample return self.sample( input_ids, logits_processor=logits_processor, @@ -1637,7 +1669,7 @@ def _generate( f"num_beams >= num_return_sequences, got {num_beams} and {num_return_sequences} (respectivelly)" ) - # 8. broadcast inputs to the desired number of beams + # 9. broadcast inputs to the desired number of beams input_ids = self._expand_to_num_beams(input_ids, num_beams=num_beams) if "encoder_outputs" in model_kwargs: @@ -1650,7 +1682,7 @@ def _generate( model_kwargs["attention_mask"], num_beams=num_beams ) - # 9. run beam search + # 10. run beam search return self.beam_search( input_ids, max_length=max_length, diff --git a/src/transformers/generation_utils.py b/src/transformers/generation_utils.py index 1792545e45476..4a2f9dfdfe17b 100644 --- a/src/transformers/generation_utils.py +++ b/src/transformers/generation_utils.py @@ -70,8 +70,8 @@ class GreedySearchDecoderOnlyOutput(ModelOutput): if all batches finished early due to the `eos_token_id`. scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`): Processed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax) - at each generation step. `(max_length-input_ids.shape[-1],)`-shaped tuple of `torch.FloatTensor` with each - tensor of shape `(batch_size, config.vocab_size)`). + at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for + each generated token), with each tensor of shape `(batch_size, config.vocab_size)`. attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`): Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of `torch.FloatTensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`. @@ -100,8 +100,8 @@ class GreedySearchEncoderDecoderOutput(ModelOutput): if all batches finished early due to the `eos_token_id`. scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`): Processed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax) - at each generation step. `(max_length-1,)`-shaped tuple of `torch.FloatTensor` with each tensor of shape - `(batch_size, config.vocab_size)`). + at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for + each generated token), with each tensor of shape `(batch_size, config.vocab_size)`. encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`): Tuple of `torch.FloatTensor` (one for each layer of the decoder) of shape `(batch_size, num_heads, sequence_length, sequence_length)`. @@ -140,8 +140,8 @@ class SampleDecoderOnlyOutput(ModelOutput): if all batches finished early due to the `eos_token_id`. scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`): Processed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax) - at each generation step. `(max_length-input_ids.shape[-1],)`-shaped tuple of `torch.FloatTensor` with each - tensor of shape `(batch_size*num_return_sequences, config.vocab_size)`). + at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for + each generated token), with each tensor of shape `(batch_size*num_return_sequences, config.vocab_size)`. attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`): Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of `torch.FloatTensor` of shape `(num_return_sequences*batch_size, num_heads, generated_length, @@ -171,8 +171,8 @@ class SampleEncoderDecoderOutput(ModelOutput): if all batches finished early due to the `eos_token_id`. scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`): Processed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax) - at each generation step. `(max_length-1,)`-shaped tuple of `torch.FloatTensor` with each tensor of shape - `(batch_size*num_return_sequences, config.vocab_size)`). + at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for + each generated token), with each tensor of shape `(batch_size*num_return_sequences, config.vocab_size)`. encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`): Tuple of `torch.FloatTensor` (one for each layer of the decoder) of shape `(batch_size*num_return_sequences, num_heads, sequence_length, sequence_length)`. @@ -214,8 +214,8 @@ class BeamSearchDecoderOnlyOutput(ModelOutput): scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`): Beam transition scores for each vocabulary token at each generation step. Beam transition scores consisting of log probabilities of tokens conditioned on log softmax of previously generated tokens in this beam. - `(max_length-input_ids.shape[-1],)`-shaped tuple of `torch.FloatTensor` with each tensor of shape - `(batch_size*num_beams*num_return_sequences, config.vocab_size)`). + Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for each generated token), + with each tensor of shape `(batch_size*num_beams*num_return_sequences, config.vocab_size)`. beam_indices (`tuple(tuple(torch.LongTensor))`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`): Beam indices of generated token id at each generation step. `torch.LongTensor` of shape `(batch_size*num_return_sequences, input_ids.shape[-1])`. @@ -251,8 +251,8 @@ class BeamSearchEncoderDecoderOutput(ModelOutput): scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`): Beam transition scores for each vocabulary token at each generation step. Beam transition scores consisting of log probabilities of tokens conditioned on log softmax of previously generated tokens in this beam. - `(max_length-1,)`-shaped tuple of `torch.FloatTensor` with each tensor of shape `(batch_size*num_beams, - config.vocab_size)`). + Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for each generated token), + with each tensor of shape `(batch_size*num_beams, config.vocab_size)`. beam_indices (`tuple(tuple(torch.LongTensor))`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`): Beam indices of generated token id at each generation step. `torch.LongTensor` of shape `(batch_size*num_return_sequences, max_length-1)`. @@ -300,8 +300,8 @@ class BeamSampleDecoderOnlyOutput(ModelOutput): scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`): Beam transition scores for each vocabulary token at each generation step. Beam transition scores consisting of log probabilities of tokens conditioned on log softmax of previously generated tokens in this beam. - `(max_length-input_ids.shape[-1],)`-shaped tuple of `torch.FloatTensor` with each tensor of shape - `(batch_size*num_beams*num_return_sequences, config.vocab_size)`). + Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for each generated token), + with each tensor of shape `(batch_size*num_beams*num_return_sequences, config.vocab_size)`. beam_indices (`tuple(tuple(torch.LongTensor))`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`): Beam indices of generated token id at each generation step. `torch.LongTensor` of shape `(batch_size*num_return_sequences, input_ids.shape[-1])`. @@ -337,8 +337,8 @@ class BeamSampleEncoderDecoderOutput(ModelOutput): scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`): Beam transition scores for each vocabulary token at each generation step. Beam transition scores consisting of log probabilities of tokens conditioned on log softmax of previously generated tokens in this beam. - `(max_length-1,)`-shaped tuple of `torch.FloatTensor` with each tensor of shape `(batch_size*num_beams, - config.vocab_size)`). + Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for each generated token), + with each tensor of shape `(batch_size*num_beams, config.vocab_size)`). beam_indices (`torch.LongTensor`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`): Beam indices of generated token id at each generation step. `torch.LongTensor` of shape `(batch_size*num_return_sequences, max_length-1)`. @@ -923,10 +923,11 @@ def generate( should of in the format of `input_ids`. For encoder-decoder models *inputs* can represent any of `input_ids`, `input_values`, `input_features`, or `pixel_values`. max_length (`int`, *optional*, defaults to `model.config.max_length`): - The maximum length of the sequence to be generated. - max_new_tokens (`int`, *optional*, defaults to None): - The maximum numbers of tokens to generate, ignore the current number of tokens. Use either - `max_new_tokens` or `max_length` but not both, they serve the same purpose. + The maximum length the generated tokens can have. Corresponds to the length of the input prompt + + `max_new_tokens`. In general, prefer the use of `max_new_tokens`, which ignores the number of tokens in + the prompt. + max_new_tokens (`int`, *optional*): + The maximum numbers of tokens to generate, ignoring the number of tokens in the prompt. min_length (`int`, *optional*, defaults to 10): The minimum length of the sequence to be generated. do_sample (`bool`, *optional*, defaults to `False`): @@ -974,7 +975,7 @@ def generate( where one can allow different forms of each word. num_return_sequences(`int`, *optional*, defaults to 1): The number of independently computed returned sequences for each element in the batch. - max_time(`float`, *optional*, defaults to None): + max_time(`float`, *optional*): The maximum amount of time you allow the computation to run for in seconds. generation will still finish the current pass after allocated time has been passed. attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): @@ -1195,20 +1196,25 @@ def generate( # if decoder-only then inputs_tensor has to be `input_ids` input_ids = inputs_tensor + # 5. Prepare `max_length` depending on other stopping criteria. input_ids_seq_length = input_ids.shape[-1] - - # 5. Prepare `max_length` depending on other stopping criteria - # if `max_new_tokens` is passed, but not `max_length` -> set `max_length = max_new_tokens` - if max_length is None and max_new_tokens is not None: - max_length = max_new_tokens + input_ids_seq_length - elif max_length is not None and max_new_tokens is not None: - # Both are set, this is odd, raise a warning + if max_length is None and max_new_tokens is None: warnings.warn( - "Both `max_length` and `max_new_tokens` have been set " - f"but they serve the same purpose. `max_length` {max_length} " - f"will take priority over `max_new_tokens` {max_new_tokens}.", + "Neither `max_length` nor `max_new_tokens` have been set, `max_length` will default to " + f"{self.config.max_length} (`self.config.max_length`). Controlling `max_length` via the config is " + "deprecated and `max_length` will be removed from the config in v5 of Transformers -- we recommend " + "using `max_new_tokens` to control the maximum length of the generation.", UserWarning, ) + elif max_length is None and max_new_tokens is not None: + max_length = max_new_tokens + input_ids_seq_length + elif max_length is not None and max_new_tokens is not None: + raise ValueError( + "Both `max_new_tokens` and `max_length` have been set but they serve the same purpose -- setting a" + " limit to the generated output length. Remove one of those arguments. Please refer to the" + " documentation for more information. " + "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)" + ) # default to config if still None max_length = max_length if max_length is not None else self.config.max_length min_length = min_length if min_length is not None else self.config.min_length @@ -1221,9 +1227,9 @@ def generate( if input_ids_seq_length >= max_length: input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids" logger.warning( - f"Input length of {input_ids_string} is {input_ids_seq_length}, but ``max_length`` is set to" - f" {max_length}. This can lead to unexpected behavior. You should consider increasing" - " ``config.max_length`` or ``max_length``." + f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to" + f" {max_length}. This can lead to unexpected behavior. You should consider increasing " + "`max_new_tokens`." ) # 6. determine generation mode diff --git a/tests/generation/test_generation_utils.py b/tests/generation/test_generation_utils.py index ce12d631bf39c..56227403ae60b 100644 --- a/tests/generation/test_generation_utils.py +++ b/tests/generation/test_generation_utils.py @@ -2023,8 +2023,8 @@ def test_max_new_tokens_encoder_decoder(self): # 1 BOS + 20 + 3 new tokens self.assertEqual(list(outputs.shape), [1, 24]) - # max_new_tokens and max_length serve the same purpose and should not be used together. - with self.assertWarns(UserWarning): + # max_new_tokens and max_length serve the same purpose and must not be used together. + with self.assertRaises(ValueError): bart_model.generate(decoder_input_ids=input_ids, max_new_tokens=10, max_length=20) def test_max_new_tokens_decoder_only(self): @@ -2050,8 +2050,8 @@ def test_max_new_tokens_decoder_only(self): # 1 BOS token + 23 new tokens self.assertEqual(list(outputs.shape), [1, 24]) - # max_new_tokens and max_length serve the same purpose and should not be used together. - with self.assertWarns(UserWarning): + # max_new_tokens and max_length serve the same purpose and must not be used together. + with self.assertRaises(ValueError): gpt2_model.generate(decoder_input_ids=input_ids, max_new_tokens=10, max_length=20) def test_encoder_decoder_generate_with_inputs_embeds(self):