
Generate: deprecate generation relying on default max_length #18018

Merged: 8 commits, Jul 23, 2022

Changes from 6 commits
51 changes: 42 additions & 9 deletions src/transformers/generation_flax_utils.py
@@ -15,6 +15,7 @@
# limitations under the License.


import warnings
from functools import partial
from typing import Dict, Optional

@@ -163,6 +164,7 @@ def generate(
self,
input_ids: jnp.ndarray,
max_length: Optional[int] = None,
max_new_tokens: Optional[int] = None,
pad_token_id: Optional[int] = None,
bos_token_id: Optional[int] = None,
eos_token_id: Optional[int] = None,
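For reference, here is a minimal usage sketch of the new argument (the checkpoint name and surrounding setup are illustrative, not part of this diff; JAX and Flax must be installed):

```python
from transformers import AutoTokenizer, FlaxAutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # illustrative checkpoint
model = FlaxAutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("Flax generation example", return_tensors="np")
# New in this PR: bound the generated continuation directly,
# regardless of how many tokens the prompt contains.
outputs = model.generate(inputs["input_ids"], max_new_tokens=20)
print(tokenizer.batch_decode(outputs.sequences, skip_special_tokens=True))
```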
@@ -209,8 +211,11 @@

input_ids (`jnp.ndarray` of shape `(batch_size, sequence_length)`):
The sequence used as a prompt for the generation.
max_length (`int`, *optional*, defaults to 20):
The maximum length of the sequence to be generated.
max_length (`int`, *optional*, defaults to `model.config.max_length`):
The maximum length of the sequence to be generated. Prefer the use of `max_new_tokens`, which ignores
the number of tokens in the prompt.
max_new_tokens (`int`, *optional*):
The maximum number of tokens to generate, ignoring the number of tokens in the prompt.
do_sample (`bool`, *optional*, defaults to `False`):
Whether or not to use sampling; use greedy decoding otherwise.
temperature (`float`, *optional*, defaults to 1.0):
@@ -258,8 +263,6 @@
>>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
```"""
# set init values
max_length = max_length if max_length is not None else self.config.max_length
min_length = min_length if min_length is not None else self.config.min_length
bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id
pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
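As the updated docstring notes, the two arguments differ only in whether the prompt counts toward the limit; a worked sketch with illustrative values:

```python
# `max_length` caps the total sequence, prompt included;
# `max_new_tokens` caps only the freshly generated part.
prompt_len = 12  # input_ids.shape[-1] for a decoder-only model
max_new_tokens = 20
max_length = max_new_tokens + prompt_len  # 32, the bound used internally

# For encoder-decoder models, generation starts from a single
# `decoder_start_token_id`, so the prompt length seen here is 1.
```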
@@ -270,11 +273,6 @@

if decoder_start_token_id is None and self.config.is_encoder_decoder:
raise ValueError("`decoder_start_token_id` has to be defined for encoder-decoder generation.")
if min_length is not None and min_length > max_length:
raise ValueError(
f"Unfeasible length constraints: the minimum length ({min_length}) is larger than the maximum "
f"length ({max_length})"
)

if self.config.is_encoder_decoder:
# add encoder_outputs to model_kwargs
@@ -283,6 +281,41 @@
# prepare decoder_input_ids for generation
input_ids = jnp.ones((input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id

# Prepare `max_length` depending on other stopping criteria.
input_ids_seq_length = input_ids.shape[-1]
if max_length is None and max_new_tokens is None:
warnings.warn(
"Neither `max_length` nor `max_new_tokens` have been set, `max_length` will default to "
f"{self.config.max_length} (`self.config.max_length`). This behavior is deprecated and will be "
"removed in v5 of Transformers -- we recommend using `max_new_tokens` to control the maximum length "
"of the generation.",
UserWarning,
)
elif max_length is None and max_new_tokens is not None:
max_length = max_new_tokens + input_ids_seq_length
elif max_length is not None and max_new_tokens is not None:
raise ValueError(
"Both `max_new_tokens` and `max_length` have been set but they serve the same purpose -- setting a"
" limit to the generated output length. Remove one of those arguments. Please refer to the"
" documentation for more information."
)
# default to config if still None
max_length = max_length if max_length is not None else self.config.max_length
min_length = min_length if min_length is not None else self.config.min_length

if min_length is not None and min_length > max_length:
raise ValueError(
f"Unfeasible length constraints: the minimum length ({min_length}) is larger than the maximum "
f"length ({max_length})"
)
if input_ids_seq_length >= max_length:
input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
logger.warning(
f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to"
f" {max_length}. This can lead to unexpected behavior. You should consider increasing"
" `max_new_tokens`."
)

do_sample = do_sample if do_sample is not None else self.config.do_sample
num_beams = num_beams if num_beams is not None else self.config.num_beams

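Taken together, the new precedence rules behave roughly as in the standalone sketch below (an approximation of the logic above, not the actual method; `config_max_length` stands in for `self.config.max_length`):

```python
import warnings


def resolve_max_length(prompt_len, config_max_length=20, max_length=None, max_new_tokens=None):
    """Sketch of the `max_length` resolution introduced by this PR."""
    if max_length is None and max_new_tokens is None:
        # Deprecated path: silently fall back to the model config.
        warnings.warn(
            f"Neither `max_length` nor `max_new_tokens` were set; defaulting to {config_max_length}.",
            UserWarning,
        )
        max_length = config_max_length
    elif max_length is None:
        # Only `max_new_tokens` given: convert it to an absolute bound.
        max_length = max_new_tokens + prompt_len
    elif max_new_tokens is not None:
        # Both given: the request is ambiguous, so refuse.
        raise ValueError("Both `max_new_tokens` and `max_length` have been set. Remove one of them.")
    return max_length


assert resolve_max_length(prompt_len=5, max_new_tokens=10) == 15
assert resolve_max_length(prompt_len=5, max_length=12) == 12
```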