Skip to content

Commit

Permalink
Fixed the docstring and type hint for forced_decoder_ids option in Ge… (
Browse files Browse the repository at this point in the history
  • Loading branch information
koreyou committed Oct 17, 2022
1 parent f2ecb9e commit 82e360b
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 18 deletions.
7 changes: 4 additions & 3 deletions src/transformers/generation_logits_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -735,10 +735,11 @@ def __call__(self, input_ids, scores):


class ForceTokensLogitsProcessor(LogitsProcessor):
r"""This processor can be used to force a list of tokens. The processor will set their log probs to `inf` so that they
are sampled at their corresponding index."""
r"""This processor takes a list of pairs of integers which indicates a mapping from generation indices to token
indices that will be forced before sampling. The processor will set their log probs to `inf` so that they are
sampled at their corresponding index."""

def __init__(self, force_token_map):
def __init__(self, force_token_map: List[List[int]]):
self.force_token_map = dict(force_token_map)

def __call__(self, input_ids, scores):
Expand Down
7 changes: 4 additions & 3 deletions src/transformers/generation_tf_logits_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -547,10 +547,11 @@ def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.


class TFForceTokensLogitsProcessor(TFLogitsProcessor):
r"""This processor can be used to force a list of tokens. The processor will set their log probs to `0` and all
other tokens to `-inf` so that they are sampled at their corresponding index."""
r"""This processor takes a list of pairs of integers which indicates a mapping from generation indices to token
indices that will be forced before sampling. The processor will set their log probs to `0` and all other tokens to
`-inf` so that they are sampled at their corresponding index."""

def __init__(self, force_token_map):
def __init__(self, force_token_map: List[List[int]]):
force_token_map = dict(force_token_map)
# Converts the dictionary of format {index: token} containing the tokens to be forced to an array, where the
# index of the array corresponds to the index of the token to be forced, for XLA compatibility.
Expand Down
17 changes: 10 additions & 7 deletions src/transformers/generation_tf_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,7 +406,7 @@ def generate(
forced_eos_token_id=None,
suppress_tokens: Optional[List[int]] = None,
begin_suppress_tokens: Optional[List[int]] = None,
forced_decoder_ids: Optional[List[int]] = None,
forced_decoder_ids: Optional[List[List[int]]] = None,
**model_kwargs,
) -> Union[TFGreedySearchOutput, TFSampleOutput, TFBeamSearchOutput, TFBeamSampleOutput, tf.Tensor]:
r"""
Expand Down Expand Up @@ -506,8 +506,10 @@ def generate(
begin_suppress_tokens (`List[int]`, *optional*, defaults to `model.config.begin_suppress_tokens`):
A list of tokens that will be suppressed at the beginning of the generation. The `SupressBeginTokens`
logit processor will set their log probs to `-inf` so that they are not sampled.
forced_decoder_ids (`List[int]`, *optional*, defaults to `model.config.forced_decoder_ids`):
A list of tokens that will be forced as beginning tokens, before sampling.
forced_decoder_ids (`List[List[int]]`, *optional*, defaults to `model.config.forced_decoder_ids`):
A list of pairs of integers which indicates a mapping from generation indices to token indices that
will be forced before sampling. For example, `[[1, 123]]` means the second generated token will always
be a token of index 123.
model_specific_kwargs:
Additional model specific kwargs will be forwarded to the `forward` function of the model.
Expand Down Expand Up @@ -1493,9 +1495,10 @@ def _generate(
begin_suppress_tokens (`List[int]`, *optional*, defaults to `model.config.begin_suppress_tokens`):
A list of tokens that will be suppressed at the beginning of the generation. The `SupressBeginTokens`
logit processor will set their log probs to `-inf` so that they are not sampled.
forced_decoder_ids (`List[int]`, *optional*, defaults to `model.config.forced_decoder_ids`):
A list of tokens that will be forced as beginning tokens.
forced_decoder_ids (`List[List[int]]`, *optional*, defaults to `model.config.forced_decoder_ids`):
A list of pairs of integers which indicates a mapping from generation indices to token indices that
will be forced before sampling. For example, `[[1, 123]]` means the second generated token will always
be a token of index 123.
model_kwargs:
Additional model specific kwargs will be forwarded to the `call` function of the model.
Expand Down Expand Up @@ -2147,7 +2150,7 @@ def _get_logits_processor(
forced_eos_token_id: int,
suppress_tokens: Optional[List[int]] = None,
begin_suppress_tokens: Optional[List[int]] = None,
forced_decoder_ids: Optional[List[int]] = None,
forced_decoder_ids: Optional[List[List[int]]] = None,
) -> TFLogitsProcessorList:
"""
This method returns a [`TFLogitsProcessorList`] list object that contains all relevant [`TFLogitsProcessor`]
Expand Down
11 changes: 6 additions & 5 deletions src/transformers/generation_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -696,7 +696,7 @@ def _get_logits_processor(
renormalize_logits: Optional[bool],
suppress_tokens: Optional[List[int]] = None,
begin_suppress_tokens: Optional[List[int]] = None,
forced_decoder_ids: Optional[List[int]] = None,
forced_decoder_ids: Optional[List[List[int]]] = None,
) -> LogitsProcessorList:
"""
This method returns a [`LogitsProcessorList`] list object that contains all relevant [`LogitsProcessor`]
Expand Down Expand Up @@ -956,7 +956,7 @@ def generate(
exponential_decay_length_penalty: Optional[Tuple[int, float]] = None,
suppress_tokens: Optional[List[int]] = None,
begin_suppress_tokens: Optional[List[int]] = None,
forced_decoder_ids: Optional[List[int]] = None,
forced_decoder_ids: Optional[List[List[int]]] = None,
**model_kwargs,
) -> Union[GreedySearchOutput, SampleOutput, BeamSearchOutput, BeamSampleOutput, torch.LongTensor]:
r"""
Expand Down Expand Up @@ -1121,9 +1121,10 @@ def generate(
begin_suppress_tokens (`List[int]`, *optional*, defaults to `model.config.begin_suppress_tokens`):
A list of tokens that will be suppressed at the beginning of the generation. The `SupressBeginTokens`
logit processor will set their log probs to `-inf` so that they are not sampled.
forced_decoder_ids (`List[int]`, *optional*, defaults to `model.config.forced_decoder_ids`):
A list of tokens that will be forced as beginning tokens, before sampling.
forced_decoder_ids (`List[List[int]]`, *optional*, defaults to `model.config.forced_decoder_ids`):
A list of pairs of integers which indicates a mapping from generation indices to token indices that
will be forced before sampling. For example, `[[1, 123]]` means the second generated token will always
be a token of index 123.
model_kwargs:
Additional model specific kwargs will be forwarded to the `forward` function of the model. If the model
is an encoder-decoder model, encoder specific kwargs should not be prefixed and decoder specific kwargs
Expand Down

0 comments on commit 82e360b

Please sign in to comment.