From 481a87882c48e3de6692731072f8eeba759ee031 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Thu, 23 Jun 2022 12:28:08 +0100 Subject: [PATCH] TF: generate without `tf.TensorArray` (#17801) --- src/transformers/generation_tf_utils.py | 283 ++++++------------ .../models/gpt2/modeling_tf_gpt2.py | 5 +- .../models/xlnet/modeling_tf_xlnet.py | 9 +- 3 files changed, 97 insertions(+), 200 deletions(-) diff --git a/src/transformers/generation_tf_utils.py b/src/transformers/generation_tf_utils.py index f27a772c084f16..6d3d105cced62e 100644 --- a/src/transformers/generation_tf_utils.py +++ b/src/transformers/generation_tf_utils.py @@ -16,7 +16,6 @@ import inspect from dataclasses import dataclass -from functools import partial from typing import Any, Dict, List, Optional, Tuple, Union import numpy as np @@ -1979,6 +1978,8 @@ def greedy_search( return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate ) use_xla = not tf.executing_eagerly() + # some models, like XLNet, need more than the last token in the presence of past + needs_full_input = "use_mems" in set(inspect.signature(self.prepare_inputs_for_generation).parameters.keys()) # 2. init `attentions`, `hidden_states`, and `scores` tuples scores = [] if (return_dict_in_generate and output_scores) else None @@ -1989,34 +1990,25 @@ def greedy_search( # 3. init tensors to use for "xla-compileable" generate function batch_size, cur_len = input_ids.shape - # initialize `generated` (pre-populated with `pad_token_id`), `finished_sequences` - generated = tf.TensorArray( - element_shape=(batch_size,), - dtype=tf.int32, - dynamic_size=False, - size=max_length, - clear_after_read=False, - ) - if pad_token_id: # ignores the cases when it is 0 or None - for i in range(max_length): - generated = generated.write(i, tf.broadcast_to(pad_token_id, (batch_size,))) - - # write prompt to generated - for i in range(cur_len): - generated = generated.write(i, input_ids[:, i]) - + # initialize `generated` (`input_ids` padded with `pad_token_id`), `finished_sequences` + input_ids_padding = tf.ones((batch_size, max_length - cur_len), dtype=tf.int32) * (pad_token_id or 0) + generated = tf.concat([input_ids, input_ids_padding], axis=-1) finished_sequences = tf.zeros((batch_size,), dtype=tf.bool) # 4. 
define "xla-compile-able" stop-condition and auto-regressive function # define condition fn - def greedy_search_cond_fn(generated, finished_sequences, next_tokens, cur_len, model_kwargs): + def greedy_search_cond_fn(generated, finished_sequences, cur_len, model_kwargs): """state termination condition fn.""" return ~tf.reduce_all(finished_sequences) # define condition fn - def greedy_search_body_fn(generated, finished_sequences, next_tokens, cur_len, model_kwargs): + def greedy_search_body_fn(generated, finished_sequences, cur_len, model_kwargs): """state update fn.""" - model_inputs = self.prepare_inputs_for_generation(next_tokens, **model_kwargs) + if model_kwargs.get("past") is None or needs_full_input: + input_ids = generated[:, :cur_len] + else: + input_ids = tf.expand_dims(generated[:, cur_len - 1], -1) + model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) # forward pass to get next token logits outputs = self( **model_inputs, @@ -2043,8 +2035,7 @@ def greedy_search_body_fn(generated, finished_sequences, next_tokens, cur_len, m decoder_hidden_states.append(outputs.hidden_states) # pre-process distribution - input_ids = tf.transpose(tf.reshape(generated.concat(), (-1, batch_size))) - next_tokens_scores = logits_processor(input_ids, next_token_logits, cur_len) + next_tokens_scores = logits_processor(generated, next_token_logits, cur_len) # argmax next_tokens = tf.argmax(next_tokens_scores, axis=-1, output_type=tf.int32) @@ -2057,8 +2048,8 @@ def greedy_search_body_fn(generated, finished_sequences, next_tokens, cur_len, m finished_sequences = finished_sequences | (next_tokens == eos_token_id) # update `generated` and `cur_len` - generated = generated.write(cur_len, next_tokens) - next_tokens = next_tokens[:, None] + update_indices = tf.stack([tf.range(batch_size), tf.broadcast_to(cur_len, [batch_size])], axis=-1) + generated = tf.tensor_scatter_nd_update(tensor=generated, indices=update_indices, updates=next_tokens) cur_len += 1 # update model_kwargs @@ -2073,34 +2064,29 @@ def greedy_search_body_fn(generated, finished_sequences, next_tokens, cur_len, m # let's throw out `past` since we don't want `None` tensors model_kwargs.pop("past", None) - next_tokens = tf.reshape(generated.concat(), (-1, batch_size)) - next_tokens = tf.transpose(next_tokens[:cur_len]) - - return generated, finished_sequences, next_tokens, cur_len, model_kwargs + return generated, finished_sequences, cur_len, model_kwargs # 5. run generation # 1st generation step has to be run before to initialize `past` - generated, finished_sequences, next_tokens, cur_len, model_kwargs = greedy_search_body_fn( - generated, finished_sequences, input_ids, cur_len, model_kwargs + generated, finished_sequences, cur_len, model_kwargs = greedy_search_body_fn( + generated, finished_sequences, cur_len, model_kwargs ) # 2-to-n generation steps can then be run in autoregressive fashion # only in case 1st generation step does NOT yield EOS token though - if greedy_search_cond_fn(generated, finished_sequences, next_tokens, cur_len, model_kwargs): + if greedy_search_cond_fn(generated, finished_sequences, cur_len, model_kwargs): maximum_iterations = max_length - cur_len - generated, _, _, cur_len, _ = tf.while_loop( + generated, _, cur_len, _ = tf.while_loop( greedy_search_cond_fn, greedy_search_body_fn, - (generated, finished_sequences, next_tokens, cur_len, model_kwargs), + (generated, finished_sequences, cur_len, model_kwargs), maximum_iterations=maximum_iterations, ) # 6. 
prepare outputs - output_ids = tf.transpose(tf.reshape(generated.concat(), (-1, batch_size))) - if not use_xla: # cut for backward compatibility - output_ids = output_ids[:, :cur_len] + generated = generated[:, :cur_len] if return_dict_in_generate: if self.config.is_encoder_decoder: @@ -2117,7 +2103,7 @@ def greedy_search_body_fn(generated, finished_sequences, next_tokens, cur_len, m decoder_hidden_states = tuple(decoder_hidden_states) if decoder_hidden_states is not None else None return TFGreedySearchEncoderDecoderOutput( - sequences=output_ids, + sequences=generated, scores=scores, encoder_attentions=encoder_attentions, encoder_hidden_states=encoder_hidden_states, @@ -2127,13 +2113,13 @@ def greedy_search_body_fn(generated, finished_sequences, next_tokens, cur_len, m ) else: return TFGreedySearchDecoderOnlyOutput( - sequences=output_ids, + sequences=generated, scores=scores, attentions=decoder_attentions, hidden_states=decoder_hidden_states, ) else: - return output_ids + return generated def sample( self, @@ -2250,6 +2236,8 @@ def sample( return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate ) use_xla = not tf.executing_eagerly() + # some models, like XLNet, need more than the last token in the presence of past + needs_full_input = "use_mems" in set(inspect.signature(self.prepare_inputs_for_generation).parameters.keys()) # 2. init `attentions`, `hidden_states`, and `scores` tuples scores = [] if (return_dict_in_generate and output_scores) else None @@ -2261,29 +2249,20 @@ def sample( batch_size, cur_len = input_ids.shape # initialize `generated` (pre-populated with `pad_token_id`), `finished_sequences` - generated = tf.TensorArray( - element_shape=(batch_size,), - dtype=tf.int32, - dynamic_size=False, - size=max_length, - clear_after_read=False, - ) - if pad_token_id: # ignores the cases when it is 0 or None - for i in range(max_length): - generated = generated.write(i, tf.broadcast_to(pad_token_id, (batch_size,))) - - # write prompt to generated - for i in range(cur_len): - generated = generated.write(i, input_ids[:, i]) - + input_ids_padding = tf.ones((batch_size, max_length - cur_len), dtype=tf.int32) * (pad_token_id or 0) + generated = tf.concat([input_ids, input_ids_padding], axis=-1) finished_sequences = tf.zeros((batch_size,), dtype=tf.bool) # 4. 
define "xla-compile-able" stop-condition and auto-regressive function - def sample_cond_fn(generated, finished_sequences, next_tokens, cur_len, model_kwargs): + def sample_cond_fn(generated, finished_sequences, cur_len, model_kwargs): return ~tf.reduce_all(finished_sequences) - def sample_body_fn(generated, finished_sequences, next_tokens, cur_len, model_kwargs): - model_inputs = self.prepare_inputs_for_generation(next_tokens, **model_kwargs) + def sample_body_fn(generated, finished_sequences, cur_len, model_kwargs): + if model_kwargs.get("past") is None or needs_full_input: + input_ids = generated[:, :cur_len] + else: + input_ids = tf.expand_dims(generated[:, cur_len - 1], -1) + model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) # forward pass to get next token logits outputs = self( **model_inputs, @@ -2310,9 +2289,8 @@ def sample_body_fn(generated, finished_sequences, next_tokens, cur_len, model_kw decoder_hidden_states.append(outputs.hidden_states) # pre-process distribution - input_ids = tf.transpose(tf.reshape(generated.concat(), (-1, batch_size))) - next_tokens_scores = logits_processor(input_ids, next_token_logits, cur_len) - next_tokens_scores = logits_warper(input_ids, next_tokens_scores, cur_len) + next_tokens_scores = logits_processor(generated, next_token_logits, cur_len) + next_tokens_scores = logits_warper(generated, next_tokens_scores, cur_len) # sample if seed is not None: @@ -2334,8 +2312,8 @@ def sample_body_fn(generated, finished_sequences, next_tokens, cur_len, model_kw finished_sequences = finished_sequences | (next_tokens == eos_token_id) # update `generated` and `cur_len` - generated = generated.write(cur_len, next_tokens) - next_tokens = next_tokens[:, None] + update_indices = tf.stack([tf.range(batch_size), tf.broadcast_to(cur_len, [batch_size])], axis=-1) + generated = tf.tensor_scatter_nd_update(tensor=generated, indices=update_indices, updates=next_tokens) cur_len += 1 # update model_kwargs @@ -2350,34 +2328,29 @@ def sample_body_fn(generated, finished_sequences, next_tokens, cur_len, model_kw # let's throw out `past` since we don't want `None` tensors model_kwargs.pop("past", None) - next_tokens = tf.reshape(generated.concat(), (-1, batch_size)) - next_tokens = tf.transpose(next_tokens[:cur_len]) - - return generated, finished_sequences, next_tokens, cur_len, model_kwargs + return generated, finished_sequences, cur_len, model_kwargs # 5. run generation # 1st generation step has to be run before to initialize `past` - generated, finished_sequences, next_tokens, cur_len, model_kwargs = sample_body_fn( - generated, finished_sequences, input_ids, cur_len, model_kwargs + generated, finished_sequences, cur_len, model_kwargs = sample_body_fn( + generated, finished_sequences, cur_len, model_kwargs ) # 2-to-n generation steps can then be run in autoregressive fashion # only in case 1st generation step does NOT yield EOS token though - if sample_cond_fn(generated, finished_sequences, next_tokens, cur_len, model_kwargs): + if sample_cond_fn(generated, finished_sequences, cur_len, model_kwargs): maximum_iterations = max_length - cur_len - generated, _, _, cur_len, _ = tf.while_loop( + generated, _, cur_len, _ = tf.while_loop( sample_cond_fn, sample_body_fn, - (generated, finished_sequences, next_tokens, cur_len, model_kwargs), + (generated, finished_sequences, cur_len, model_kwargs), maximum_iterations=maximum_iterations, ) # 6. 
prepare outputs - output_ids = tf.transpose(tf.reshape(generated.concat(), (-1, batch_size))) - if not use_xla: # cut for backward compatibility - output_ids = output_ids[:, :cur_len] + generated = generated[:, :cur_len] if return_dict_in_generate: if self.config.is_encoder_decoder: @@ -2394,7 +2367,7 @@ def sample_body_fn(generated, finished_sequences, next_tokens, cur_len, model_kw decoder_hidden_states = tuple(decoder_hidden_states) if decoder_hidden_states is not None else None return TFSampleEncoderDecoderOutput( - sequences=output_ids, + sequences=generated, scores=scores, encoder_attentions=encoder_attentions, encoder_hidden_states=encoder_hidden_states, @@ -2404,13 +2377,13 @@ def sample_body_fn(generated, finished_sequences, next_tokens, cur_len, model_kw ) else: return TFSampleDecoderOnlyOutput( - sequences=output_ids, + sequences=generated, scores=scores, attentions=decoder_attentions, hidden_states=decoder_hidden_states, ) else: - return output_ids + return generated def beam_search( self, @@ -2585,6 +2558,8 @@ def gather_fn(tensor): # GPT2 and other models has a slightly different cache structure, with a different batch axis model_name = str(self.decoder) if "EncoderDecoder" in str(self) else str(self) cache_batch_axis = 1 if any([model_prefix in model_name for model_prefix in ("TFGPT2", "TFCTRL")]) else 0 + # some models, like XLNet, need more than the last token in the presence of past + needs_full_input = "use_mems" in set(inspect.signature(self.prepare_inputs_for_generation).parameters.keys()) # 2. init `attentions`, `hidden_states`, and `scores` tuples scores = [] if (return_dict_in_generate and output_scores) else None @@ -2594,41 +2569,13 @@ def gather_fn(tensor): # 3. init tensors to use for "xla-compileable" generate function batch_size, num_beams, cur_len = input_ids.shape - input_ids_length = cur_len # per batch, beam-item holding current token in loop, pre-populated with `pad_token_id` - sequences = tf.TensorArray( - element_shape=(batch_size, num_beams), - dtype=tf.int32, - dynamic_size=False, - size=max_length, - clear_after_read=False, - ) - running_sequences = tf.TensorArray( - element_shape=(batch_size, num_beams), - dtype=tf.int32, - dynamic_size=False, - size=max_length, - clear_after_read=False, - ) - intermediary_running_sequences = tf.TensorArray( - element_shape=(batch_size, num_beams * 2), - dtype=tf.int32, - dynamic_size=False, - size=max_length, - clear_after_read=False, + input_ids_padding = tf.ones((batch_size, num_beams, max_length - cur_len), dtype=tf.int32) * ( + pad_token_id or 0 ) - if pad_token_id: # ignores the cases when it is 0 or None - for i in range(max_length): - sequences = sequences.write(i, tf.broadcast_to(pad_token_id, (batch_size, num_beams))) - running_sequences = running_sequences.write(i, tf.broadcast_to(pad_token_id, (batch_size, num_beams))) - intermediary_running_sequences = intermediary_running_sequences.write( - i, tf.broadcast_to(pad_token_id, (batch_size, num_beams * 2)) - ) - - # write prompt to running_sequences - for i in range(cur_len): - running_sequences = running_sequences.write(i, input_ids[:, :, i]) + running_sequences = tf.concat([input_ids, input_ids_padding], axis=-1) + sequences = tf.ones((batch_size, num_beams, max_length), dtype=tf.int32) * (pad_token_id or 0) # per batch,beam-item state bit indicating if sentence has finished. 
is_sent_finished = tf.zeros((batch_size, num_beams), dtype=tf.bool) @@ -2656,7 +2603,6 @@ def beam_search_cond_fn( sequences, scores, is_sent_finished, - input_ids_length, model_kwargs, ): """ @@ -2685,27 +2631,18 @@ def beam_search_body_fn( sequences, scores, is_sent_finished, - input_ids_length, model_kwargs, - intermediary_running_sequences=None, ): """ Beam Search iterative update function -- each iteration adds a new token and updates the best sequences seen so far """ - # TODO (joao): this loop is probably faster with gather/scatters, instead of using `tf.TensorArray`. - # Alternativelly, attempt to rewrite function with permuted axis, when enabling XLA. - # 1. Forward current tokens - - # TF places the dynamic dimension (seq_len) in the first axis, we want it in the last - running_sequences_seq_last = tf.transpose(running_sequences.stack(), perm=[1, 2, 0]) - input_token = tf.slice( - running_sequences_seq_last, - (0, 0, cur_len - input_ids_length), - (batch_size, num_beams, input_ids_length), - ) - model_inputs = self.prepare_inputs_for_generation(flatten_beam_dim(input_token), **model_kwargs) + if model_kwargs.get("past") is None or needs_full_input: + input_ids = running_sequences[:, :, :cur_len] + else: + input_ids = tf.expand_dims(running_sequences[:, :, cur_len - 1], -1) + model_inputs = self.prepare_inputs_for_generation(flatten_beam_dim(input_ids), **model_kwargs) model_outputs = self( **model_inputs, return_dict=True, @@ -2734,9 +2671,7 @@ def beam_search_body_fn( # get log probabilities from logits, process logits with processors (*e.g.* min_length, ...), and # add new logprobs to existing running logprobs scores. log_probs = tf.nn.log_softmax(logits) - log_probs = logits_processor( - flatten_beam_dim(running_sequences_seq_last), flatten_beam_dim(log_probs), cur_len - ) + log_probs = logits_processor(flatten_beam_dim(running_sequences), flatten_beam_dim(log_probs), cur_len) log_probs = unflatten_beam_dim(log_probs, batch_size, num_beams) log_probs = log_probs + tf.expand_dims(running_scores, axis=2) vocab_size = log_probs.shape[2] @@ -2755,23 +2690,28 @@ def beam_search_body_fn( beams_to_keep = 2 * num_beams topk_log_probs, topk_indices = tf.math.top_k(log_probs, k=beams_to_keep) topk_beam_indices = topk_indices // vocab_size - topk_running_sequences_seq_last = gather_beams(running_sequences_seq_last, topk_beam_indices) + topk_running_sequences = gather_beams(running_sequences, topk_beam_indices) topk_ids = topk_indices % vocab_size # writes the new token - intermediary_running_sequences = intermediary_running_sequences.unstack( - tf.transpose(topk_running_sequences_seq_last, perm=[2, 0, 1]) + indices_batch = tf.repeat(tf.range(batch_size), [beams_to_keep]) + indices_beam = tf.tile(tf.range(beams_to_keep), [batch_size]) + update_indices = tf.stack( + [indices_batch, indices_beam, tf.broadcast_to(cur_len, [batch_size * beams_to_keep])], axis=-1 + ) + topk_sequences = tf.tensor_scatter_nd_update( + tensor=topk_running_sequences, + indices=update_indices, + updates=tf.reshape(topk_ids, [batch_size * beams_to_keep]), ) - topk_sequences = intermediary_running_sequences.write(cur_len, topk_ids) - topk_sequences_seq_last = tf.transpose(topk_sequences.stack(), perm=[1, 2, 0]) # 4. Check which sequences have ended # Update current sequences: Did the top `num_beams` sequences reach an end marker? # To prevent these just finished sequences from being added to the current sequences # set of active beam search sequences, set their log probs to a very large negative value. 
- eos_in_next_token = topk_sequences_seq_last[:, :, cur_len] == eos_token_id + eos_in_next_token = topk_sequences[:, :, cur_len] == eos_token_id if eos_token_id is None: - eos_in_next_token = tf.broadcast_to(eos_in_next_token, topk_sequences_seq_last[:, :, cur_len].shape) + eos_in_next_token = tf.broadcast_to(eos_in_next_token, topk_sequences[:, :, cur_len].shape) did_topk_just_finished = eos_in_next_token & tf.broadcast_to( tf.concat((tf.ones((num_beams), dtype=tf.bool), tf.zeros((num_beams), dtype=tf.bool)), axis=0), eos_in_next_token.shape, @@ -2785,8 +2725,8 @@ def beam_search_body_fn( # Determine the top k beam indices (from top 2*k beams) from log probs and gather top k beams # (from top 2*k beams). next_topk_indices = tf.math.top_k(running_topk_log_probs, k=num_beams)[1] - next_running_sequences_seq_last, next_running_scores = gather_beams( - [topk_sequences_seq_last, running_topk_log_probs], next_topk_indices + next_running_sequences, next_running_scores = gather_beams( + [topk_sequences, running_topk_log_probs], next_topk_indices ) # 6. Process topk logits @@ -2807,18 +2747,18 @@ def beam_search_body_fn( # 7. Get scores, sequences, is sentence finished for next. # Combine sequences, scores, and flags along the beam dimension and compare new finished sequence scores # to existing finished scores and select the best from the new set of beams - sequences_seq_last = tf.transpose(sequences.stack(), perm=[1, 2, 0]) - merged_sequences = tf.concat([sequences_seq_last, topk_sequences_seq_last], axis=1) + merged_sequences = tf.concat([sequences, topk_sequences], axis=1) merged_scores = tf.concat([scores, topk_log_probs], axis=1) merged_is_sent_finished = tf.concat([is_sent_finished, did_topk_just_finished], axis=1) topk_merged_indices = tf.math.top_k(merged_scores, k=num_beams)[1] - next_sequences_seq_last, next_scores, next_is_sent_finished = gather_beams( + next_sequences, next_scores, next_is_sent_finished = gather_beams( [merged_sequences, merged_scores, merged_is_sent_finished], topk_merged_indices ) # 8. Prepare data for the next iteration # Determine the top k beam indices from the original set of all beams. With these, gather the top k # beam-associated caches. + cur_len = cur_len + 1 if "past_key_values" in model_outputs: cache = tf.nest.map_structure( lambda tensor: unflatten_beam_dim(tensor, batch_size, num_beams, batch_axis=cache_batch_axis), @@ -2841,35 +2781,20 @@ def beam_search_body_fn( # if we don't cache past key values we need the whole input if model_kwargs.get("past", None) is None: - next_input_ids_length = cur_len + 1 # let's throw out `past` since we don't want `None` tensors model_kwargs.pop("past", None) - else: - next_input_ids_length = 1 - - # 9. Prepare the `tf.TensorArray` for the next iteration - next_sequences = sequences.unstack(tf.transpose(next_sequences_seq_last, perm=[2, 0, 1])) - next_running_sequences = running_sequences.unstack( - tf.transpose(next_running_sequences_seq_last, perm=[2, 0, 1]) - ) return ( - cur_len + 1, + cur_len, next_running_sequences, next_running_scores, next_sequences, next_scores, next_is_sent_finished, - next_input_ids_length, next_model_kwargs, ) # 5. 
run generation - # Adds the `intermediary_running_sequences` TensorArray into the body, needed as a scratchpad - beam_search_body_fn = partial( - beam_search_body_fn, intermediary_running_sequences=intermediary_running_sequences - ) - # 1st generation step has to be run before to initialize `past` (if active) ( cur_len, @@ -2878,66 +2803,38 @@ def beam_search_body_fn( sequences, scores, is_sent_finished, - input_ids_length, model_kwargs, ) = beam_search_body_fn( - cur_len, - running_sequences, - running_scores, - sequences, - scores, - is_sent_finished, - input_ids_length, - model_kwargs, + cur_len, running_sequences, running_scores, sequences, scores, is_sent_finished, model_kwargs ) # 2-to-n generation steps can then be run in autoregressive fashion (only in case 1st generation step does # NOT yield EOS token though) if beam_search_cond_fn( - cur_len, - running_sequences, - running_scores, - sequences, - scores, - is_sent_finished, - input_ids_length, - model_kwargs, + cur_len, running_sequences, running_scores, sequences, scores, is_sent_finished, model_kwargs ): maximum_iterations = max_length - cur_len - cur_len, running_sequences, running_scores, sequences, scores, is_sent_finished, _, _ = tf.while_loop( + cur_len, running_sequences, running_scores, sequences, scores, is_sent_finished, _ = tf.while_loop( beam_search_cond_fn, beam_search_body_fn, - ( - cur_len, - running_sequences, - running_scores, - sequences, - scores, - is_sent_finished, - input_ids_length, - model_kwargs, - ), + (cur_len, running_sequences, running_scores, sequences, scores, is_sent_finished, model_kwargs), maximum_iterations=maximum_iterations, ) # 6. prepare outputs - # convert the sequneces to tf.Tensor with shape (batch_size, num_beams, seq_len) - sequences_seq_last = tf.transpose(sequences.stack(), perm=[1, 2, 0]) - running_sequences_seq_last = tf.transpose(running_sequences.stack(), perm=[1, 2, 0]) - # Account for the edge-case where there are no finished sequences for a particular batch item. If so, return # running sequences for that batch item. 
none_finished = tf.math.reduce_any(is_sent_finished, axis=1) - sequences_seq_last = tf.where(none_finished[:, None, None], sequences_seq_last, running_sequences_seq_last) + sequences = tf.where(none_finished[:, None, None], sequences, running_sequences) scores = tf.where(none_finished[:, None], scores, running_scores) # Take best beams for each batch (the score is sorted in ascending order) - sequences_seq_last = flatten_beam_dim(sequences_seq_last[:, :num_return_sequences, :]) + sequences = flatten_beam_dim(sequences[:, :num_return_sequences, :]) scores = flatten_beam_dim(scores[:, :num_return_sequences]) if not use_xla: # Cut for backward compatibility - sequences_seq_last = sequences_seq_last[:, :cur_len] + sequences = sequences[:, :cur_len] if return_dict_in_generate: if self.config.is_encoder_decoder: @@ -2948,7 +2845,7 @@ def beam_search_body_fn( ) return TFBeamSearchEncoderDecoderOutput( - sequences=sequences_seq_last, + sequences=sequences, scores=scores, encoder_attentions=encoder_attentions, encoder_hidden_states=encoder_hidden_states, @@ -2958,13 +2855,13 @@ def beam_search_body_fn( ) else: return TFBeamSearchDecoderOnlyOutput( - sequences=sequences_seq_last, + sequences=sequences, scores=scores, attentions=decoder_attentions, hidden_states=decoder_hidden_states, ) else: - return sequences_seq_last + return sequences def _create_next_token_logits_penalties(input_ids, logits, repetition_penalty): diff --git a/src/transformers/models/gpt2/modeling_tf_gpt2.py b/src/transformers/models/gpt2/modeling_tf_gpt2.py index b3d1ad048498e3..a2cd338f6d4791 100644 --- a/src/transformers/models/gpt2/modeling_tf_gpt2.py +++ b/src/transformers/models/gpt2/modeling_tf_gpt2.py @@ -874,8 +874,9 @@ def _update_model_kwargs_for_xla_generation(self, outputs, model_kwargs, current new_past = [None for _ in range(len(past))] slice_start_base = tf.constant([0, 0, 0, 1, 0]) attention_mask_update_slice = tf.ones((batch_size, 1), dtype=attention_mask.dtype) - # correct 5 here - new_past_index = current_pos - 1 + # -1 because current_pos has already been incremented before this function + # -1 again because last index = len - 1 + new_past_index = current_pos - 2 for i in range(len(past)): update_slice = past[i][:, :, :, -1:] diff --git a/src/transformers/models/xlnet/modeling_tf_xlnet.py b/src/transformers/models/xlnet/modeling_tf_xlnet.py index df4111d2631727..312a02712349a3 100644 --- a/src/transformers/models/xlnet/modeling_tf_xlnet.py +++ b/src/transformers/models/xlnet/modeling_tf_xlnet.py @@ -1202,7 +1202,6 @@ def get_prefix_bias_name(self): def prepare_inputs_for_generation(self, inputs, past=None, use_mems=None, **kwargs): # Add dummy token at the end (no attention on this one) - effective_batch_size = inputs.shape[0] dummy_token = tf.zeros((effective_batch_size, 1), dtype=inputs.dtype) @@ -1212,12 +1211,12 @@ def prepare_inputs_for_generation(self, inputs, past=None, use_mems=None, **kwar offset = 2 if past: - inputs = tf.concat([inputs[:, -offset:], dummy_token], axis=1) + input_ids = tf.concat([inputs[:, -offset:], dummy_token], axis=1) else: - inputs = tf.concat([inputs, dummy_token], axis=1) + input_ids = tf.concat([inputs, dummy_token], axis=1) # Build permutation mask so that previous tokens don't see last token - sequence_length = inputs.shape[1] + sequence_length = input_ids.shape[1] perm_mask = tf.zeros((effective_batch_size, sequence_length, sequence_length - 1)) perm_mask_seq_end = tf.ones((effective_batch_size, sequence_length, 1)) perm_mask = tf.concat([perm_mask, perm_mask_seq_end], 
axis=-1) @@ -1228,7 +1227,7 @@ def prepare_inputs_for_generation(self, inputs, past=None, use_mems=None, **kwar target_mapping = tf.concat([target_mapping, target_mapping_seq_end], axis=-1) inputs = { - "input_ids": inputs, + "input_ids": input_ids, "perm_mask": perm_mask, "target_mapping": target_mapping, "use_mems": use_mems,
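
The core change above is replacing `tf.TensorArray` with a fixed-shape buffer: `input_ids` is right-padded with `pad_token_id` up to `max_length`, and each new token is written in place with `tf.tensor_scatter_nd_update`, so every tensor in the generation loop keeps a static shape that XLA can compile. A minimal standalone sketch of that write, using toy shapes and values that are not taken from the patch:

    import tensorflow as tf

    batch_size, cur_len, max_length, pad_token_id = 2, 3, 8, 0
    input_ids = tf.constant([[5, 6, 7], [8, 9, 10]], dtype=tf.int32)

    # pre-allocate the whole output buffer: the prompt followed by pad tokens
    input_ids_padding = tf.ones((batch_size, max_length - cur_len), dtype=tf.int32) * (pad_token_id or 0)
    generated = tf.concat([input_ids, input_ids_padding], axis=-1)  # shape (batch_size, max_length), static

    # write one new token per batch item at position `cur_len`; the scatter keeps the shape fixed
    next_tokens = tf.constant([11, 12], dtype=tf.int32)
    update_indices = tf.stack([tf.range(batch_size), tf.broadcast_to(cur_len, [batch_size])], axis=-1)
    generated = tf.tensor_scatter_nd_update(tensor=generated, indices=update_indices, updates=next_tokens)
    # generated is now [[5, 6, 7, 11, 0, 0, 0, 0], [8, 9, 10, 12, 0, 0, 0, 0]]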
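
The `needs_full_input` flag decides how much of that buffer is fed back to the model each step: models whose `prepare_inputs_for_generation` signature contains `use_mems` (XLNet-style) always receive the full prefix, while other models receive only the last generated token once `past` exists. A sketch of that rule as a standalone helper (`select_model_inputs` is a hypothetical name, not a function from the patch):

    import inspect
    import tensorflow as tf

    def select_model_inputs(model, generated, cur_len, model_kwargs):
        # `use_mems` in the signature is the marker the patch uses for "needs the full input"
        needs_full_input = "use_mems" in set(
            inspect.signature(model.prepare_inputs_for_generation).parameters.keys()
        )
        if model_kwargs.get("past") is None or needs_full_input:
            input_ids = generated[:, :cur_len]  # whole prefix: first step, or XLNet-like models
        else:
            input_ids = tf.expand_dims(generated[:, cur_len - 1], -1)  # only the freshly written token
        return model.prepare_inputs_for_generation(input_ids, **model_kwargs)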
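
Because the loop state is now a tuple of plain, fixed-shape tensors, the same body function can be run once eagerly (to initialize `past`) and then handed to `tf.while_loop` for the remaining steps. A toy version of that structure, with the forward pass and argmax/sampling replaced by a constant token so the sketch stays self-contained:

    import tensorflow as tf

    batch_size, max_length = 2, 8

    def cond_fn(generated, finished_sequences, cur_len):
        return ~tf.reduce_all(finished_sequences)

    def body_fn(generated, finished_sequences, cur_len):
        # stand-in for the real forward pass + token selection: always pick token id 1
        next_tokens = tf.ones((batch_size,), dtype=tf.int32)
        update_indices = tf.stack([tf.range(batch_size), tf.broadcast_to(cur_len, [batch_size])], axis=-1)
        generated = tf.tensor_scatter_nd_update(generated, update_indices, next_tokens)
        return generated, finished_sequences, cur_len + 1

    generated = tf.zeros((batch_size, max_length), dtype=tf.int32)
    finished_sequences = tf.zeros((batch_size,), dtype=tf.bool)
    cur_len = tf.constant(3)

    # first step outside the loop (in the real code this builds `past`), remaining steps inside tf.while_loop
    generated, finished_sequences, cur_len = body_fn(generated, finished_sequences, cur_len)
    generated, _, cur_len = tf.while_loop(
        cond_fn, body_fn, (generated, finished_sequences, cur_len), maximum_iterations=max_length - cur_len
    )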
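
Beam search applies the same idea one rank higher: the sequence buffer has shape (batch, beam, length), so the scatter needs one (batch index, beam index, position) triple per candidate. A toy sketch of how those index triples are built, mirroring the `indices_batch`/`indices_beam` construction in the patch with made-up shapes:

    import tensorflow as tf

    batch_size, beams_to_keep, max_length, cur_len = 2, 4, 6, 2
    topk_running_sequences = tf.zeros((batch_size, beams_to_keep, max_length), dtype=tf.int32)
    topk_ids = tf.reshape(tf.range(batch_size * beams_to_keep, dtype=tf.int32), (batch_size, beams_to_keep))

    indices_batch = tf.repeat(tf.range(batch_size), [beams_to_keep])  # [0, 0, 0, 0, 1, 1, 1, 1]
    indices_beam = tf.tile(tf.range(beams_to_keep), [batch_size])     # [0, 1, 2, 3, 0, 1, 2, 3]
    update_indices = tf.stack(
        [indices_batch, indices_beam, tf.broadcast_to(cur_len, [batch_size * beams_to_keep])], axis=-1
    )  # shape (batch_size * beams_to_keep, 3)

    topk_sequences = tf.tensor_scatter_nd_update(
        tensor=topk_running_sequences,
        indices=update_indices,
        updates=tf.reshape(topk_ids, [batch_size * beams_to_keep]),
    )
    # topk_sequences[b, k, cur_len] now holds topk_ids[b, k] for every batch item and candidate beam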