From f9d35dfa056efd50400d23f40775baef4c3037ae Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Thu, 30 Jun 2022 18:04:53 +0000 Subject: [PATCH 1/2] tmp commit --- src/transformers/models/t5/modeling_tf_t5.py | 16 +++++++++++++--- tests/models/t5/test_modeling_tf_t5.py | 14 +++++--------- 2 files changed, 18 insertions(+), 12 deletions(-) diff --git a/src/transformers/models/t5/modeling_tf_t5.py b/src/transformers/models/t5/modeling_tf_t5.py index 5c8aec875b55d..a1a22d14b66c8 100644 --- a/src/transformers/models/t5/modeling_tf_t5.py +++ b/src/transformers/models/t5/modeling_tf_t5.py @@ -23,6 +23,7 @@ import numpy as np import tensorflow as tf +from tensorflow.compiler.tf2xla.python.xla import dynamic_slice from ...activations_tf import get_tf_activation from ...modeling_tf_outputs import ( @@ -384,10 +385,19 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value): else: position_bias = self.compute_bias(real_seq_length, key_length) - # if key and values are already calculated - # we want only the last query position bias + # if key and values are already calculated we want only the last query position bias if past_key_value is not None: - position_bias = position_bias[:, :, -seq_length:, :] + if not self.has_relative_attention_bias: + position_bias = position_bias[:, :, -seq_length:, :] + else: + # we might have a padded past structure, in which case we want to fetch the position bias slice for + # the most recently populated index + most_recently_filled_past_index = tf.reduce_max(tf.where(past_key_value[0][0, 0, :, 0] != 0.0)) + position_bias = dynamic_slice( + position_bias, + (0, 0, most_recently_filled_past_index, 0), + (1, self.n_heads, seq_length, real_seq_length), + ) if mask is not None: position_bias = tf.cast(position_bias, dtype=mask.dtype) diff --git a/tests/models/t5/test_modeling_tf_t5.py b/tests/models/t5/test_modeling_tf_t5.py index c67851a054148..b87e460c4d762 100644 --- a/tests/models/t5/test_modeling_tf_t5.py +++ b/tests/models/t5/test_modeling_tf_t5.py @@ -590,21 +590,17 @@ def test_beam_search_xla_generate_simple(self): ] input_ids = tokenizer(sentences, return_tensors="tf", padding=True).input_ids - # xla_generate = tf.function(model.generate, jit_compile=True) - xla_generate = tf.function(model.generate) + xla_generate = tf.function(model.generate, jit_compile=True) - # TODO (joao): there is something not quite right with XLA T5 -- as we increase `max_length` the two outputs - # drift appart, where the XLA version clearly degrades its quality. XLA-related variables look fine (they are - # being padded and filled in the right places). This also happens in other generation modes. Investigate. - output_ids = model.generate(input_ids, num_beams=2, max_length=9) - output_ids_xla = xla_generate(input_ids, num_beams=2, max_length=9) + output_ids = model.generate(input_ids, num_beams=2) + output_ids_xla = xla_generate(input_ids, num_beams=2) output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True) output_strings_xla = tokenizer.batch_decode(output_ids_xla, skip_special_tokens=True) expected_output_string = [ - "Aujourd'hui est une belle journée.", - "J'ai quatre chats,", + "Aujourd'hui, c'est une belle journée.", + "J'ai quatre chats, trois chiens, deux oiseaux et un cheval.", ] self.assertListEqual(expected_output_string, output_strings) From c048a0b148484ec33ecd9502dddcc323e431760a Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Thu, 30 Jun 2022 18:15:42 +0000 Subject: [PATCH 2/2] get the right slicing index for position_bias --- src/transformers/models/t5/modeling_tf_t5.py | 6 +++--- tests/models/t5/test_modeling_tf_t5.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/transformers/models/t5/modeling_tf_t5.py b/src/transformers/models/t5/modeling_tf_t5.py index a1a22d14b66c8..2eebdfd1cb60e 100644 --- a/src/transformers/models/t5/modeling_tf_t5.py +++ b/src/transformers/models/t5/modeling_tf_t5.py @@ -390,12 +390,12 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value): if not self.has_relative_attention_bias: position_bias = position_bias[:, :, -seq_length:, :] else: - # we might have a padded past structure, in which case we want to fetch the position bias slice for - # the most recently populated index + # we might have a padded past structure, in which case we want to fetch the position bias slice + # right after the most recently filled past index most_recently_filled_past_index = tf.reduce_max(tf.where(past_key_value[0][0, 0, :, 0] != 0.0)) position_bias = dynamic_slice( position_bias, - (0, 0, most_recently_filled_past_index, 0), + (0, 0, most_recently_filled_past_index + 1, 0), (1, self.n_heads, seq_length, real_seq_length), ) diff --git a/tests/models/t5/test_modeling_tf_t5.py b/tests/models/t5/test_modeling_tf_t5.py index b87e460c4d762..35f1d90886c98 100644 --- a/tests/models/t5/test_modeling_tf_t5.py +++ b/tests/models/t5/test_modeling_tf_t5.py @@ -599,7 +599,7 @@ def test_beam_search_xla_generate_simple(self): output_strings_xla = tokenizer.batch_decode(output_ids_xla, skip_special_tokens=True) expected_output_string = [ - "Aujourd'hui, c'est une belle journée.", + "Aujourd'hui est une belle journée.", "J'ai quatre chats, trois chiens, deux oiseaux et un cheval.", ]