Skip to content

Commit

Permalink
Generate: deprecate default max_length (huggingface#18018)
Browse files Browse the repository at this point in the history
  • Loading branch information
gante authored and oneraghavan committed Sep 26, 2022
1 parent 0719e84 commit 27aae72
Show file tree
Hide file tree
Showing 4 changed files with 166 additions and 93 deletions.
53 changes: 44 additions & 9 deletions src/transformers/generation_flax_utils.py
Expand Up @@ -15,6 +15,7 @@
# limitations under the License.


import warnings
from functools import partial
from typing import Dict, Optional

Expand Down Expand Up @@ -163,6 +164,7 @@ def generate(
self,
input_ids: jnp.ndarray,
max_length: Optional[int] = None,
max_new_tokens: Optional[int] = None,
pad_token_id: Optional[int] = None,
bos_token_id: Optional[int] = None,
eos_token_id: Optional[int] = None,
Expand Down Expand Up @@ -209,8 +211,12 @@ def generate(
input_ids (`jnp.ndarray` of shape `(batch_size, sequence_length)`):
The sequence used as a prompt for the generation.
max_length (`int`, *optional*, defaults to 20):
The maximum length of the sequence to be generated.
max_length (`int`, *optional*, defaults to `model.config.max_length`):
The maximum length the generated tokens can have. Corresponds to the length of the input prompt +
`max_new_tokens`. In general, prefer the use of `max_new_tokens`, which ignores the number of tokens in
the prompt.
max_new_tokens (`int`, *optional*):
The maximum numbers of tokens to generate, ignoring the number of tokens in the prompt.
do_sample (`bool`, *optional*, defaults to `False`):
Whether or not to use sampling ; use greedy decoding otherwise.
temperature (`float`, *optional*, defaults to 1.0):
Expand Down Expand Up @@ -258,8 +264,6 @@ def generate(
>>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
```"""
# set init values
max_length = max_length if max_length is not None else self.config.max_length
min_length = min_length if min_length is not None else self.config.min_length
bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id
pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
Expand All @@ -270,11 +274,6 @@ def generate(

if decoder_start_token_id is None and self.config.is_encoder_decoder:
raise ValueError("`decoder_start_token_id` has to be defined for encoder-decoder generation.")
if min_length is not None and min_length > max_length:
raise ValueError(
f"Unfeasable length constraints: the minimum length ({min_length}) is larger than the maximum "
f"length ({max_length})"
)

if self.config.is_encoder_decoder:
# add encoder_outputs to model_kwargs
Expand All @@ -283,6 +282,42 @@ def generate(
# prepare decoder_input_ids for generation
input_ids = jnp.ones((input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id

# Prepare `max_length` depending on other stopping criteria.
input_ids_seq_length = input_ids.shape[-1]
if max_length is None and max_new_tokens is None:
warnings.warn(
"Neither `max_length` nor `max_new_tokens` have been set, `max_length` will default to "
f"{self.config.max_length} (`self.config.max_length`). Controlling `max_length` via the config is "
"deprecated and `max_length` will be removed from the config in v5 of Transformers -- we recommend "
"using `max_new_tokens` to control the maximum length of the generation.",
UserWarning,
)
elif max_length is None and max_new_tokens is not None:
max_length = max_new_tokens + input_ids_seq_length
elif max_length is not None and max_new_tokens is not None:
raise ValueError(
"Both `max_new_tokens` and `max_length` have been set but they serve the same purpose -- setting a"
" limit to the generated output length. Remove one of those arguments. Please refer to the"
" documentation for more information. "
"(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)"
)
# default to config if still None
max_length = max_length if max_length is not None else self.config.max_length
min_length = min_length if min_length is not None else self.config.min_length

if min_length is not None and min_length > max_length:
raise ValueError(
f"Unfeasable length constraints: the minimum length ({min_length}) is larger than the maximum "
f"length ({max_length})"
)
if input_ids_seq_length >= max_length:
input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
logger.warning(
f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to"
f" {max_length}. This can lead to unexpected behavior. You should consider increasing"
"`max_new_tokens`."
)

do_sample = do_sample if do_sample is not None else self.config.do_sample
num_beams = num_beams if num_beams is not None else self.config.num_beams

Expand Down

0 comments on commit 27aae72

Please sign in to comment.