
TF: XLA beam search + most generation-compatible models are now also XLA-generate-compatible #17857

Merged (9 commits, Jun 29, 2022)
249 changes: 198 additions & 51 deletions src/transformers/generation_tf_utils.py

Large diffs are not rendered by default.
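
Since the `generation_tf_utils.py` diff is not rendered here, below is a minimal usage sketch of the XLA generation path this PR targets, mirroring the updated BART tests further down. The `t5-small` checkpoint is only an illustrative choice; any TF generation model that supports XLA generation should behave the same way.

import tensorflow as tf
from transformers import AutoTokenizer, TFAutoModelForSeq2SeqLM

tokenizer = AutoTokenizer.from_pretrained("t5-small")
model = TFAutoModelForSeq2SeqLM.from_pretrained("t5-small")
inputs = tokenizer("translate English to German: The house is wonderful.", return_tensors="tf")

# `no_repeat_ngram_size` must be disabled (set to 0) because it is not compatible with XLA;
# `num_beams=4` exercises the new XLA beam search
eager_ids = model.generate(**inputs, num_beams=4, no_repeat_ngram_size=0)

xla_generate = tf.function(model.generate, jit_compile=True)
xla_ids = xla_generate(**inputs, num_beams=4, no_repeat_ngram_size=0)

# the XLA and eager runs should decode to the same text
print(tokenizer.batch_decode(eager_ids, skip_special_tokens=True))
print(tokenizer.batch_decode(xla_ids, skip_special_tokens=True))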

64 changes: 0 additions & 64 deletions src/transformers/models/bart/modeling_tf_bart.py
@@ -20,7 +20,6 @@

import numpy as np
import tensorflow as tf
from tensorflow.compiler.tf2xla.python.xla import dynamic_update_slice

from ...activations_tf import get_tf_activation
from ...modeling_tf_outputs import (
@@ -1434,69 +1433,6 @@ def prepare_inputs_for_generation(
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
}

def _update_model_kwargs_for_xla_generation(self, outputs, model_kwargs, current_pos, max_length):
# TODO(Pvp, Joao, Matt) - this function can be cleaned a bit and refactored
# quite some duplicated code patterns it seems
past = outputs.past_key_values
is_past_initialized = model_kwargs.pop("past", None) is not None
decoder_attention_mask = model_kwargs.pop("decoder_attention_mask", None)
batch_size = past[0][0].shape[0]

if not is_past_initialized:
# past[0][0].shape[2] is seq_length of prompt
# The padded version of `past` requires only `max_length - 1` steps along the time dimension.
num_padding_values = max_length - past[0][0].shape[2] - 1
# prepare the padding tensor for `tf.pad`.
# `shape=(4, 2)` because each tensor element in `past` has `rank=4`.
# `indices=[[2, 1]]` means the time dimension (dim 2) needs **right**-padding (`1` means padding afterward).
padding_values = tf.scatter_nd(indices=[[2, 1]], updates=[num_padding_values], shape=(4, 2))

new_past = ()
for past_layer in past:
new_past_layer = list(past_layer)
for i in range(len(new_past_layer[:2])):
new_past_layer[i] = tf.pad(past_layer[i], padding_values)
Contributor (@aashiqmuhamed):

@gante I was wondering why you decided to go for right padding, as opposed to, say, left padding, which could be simpler (no special treatment for relative positional embeddings, no dynamic_update_slice required).

Member Author (@gante):

Hey @aashiqmuhamed!

To be candid, we did not even consider other ways to pre-populate the fixed-shape tensors. Maybe left-padding would lead to faster code, as it would imply a concatenation on the right and a cropping on the left (perhaps faster than a scatter operation).

I'm sure there are many optimization opportunities in the TF XLA generate codebase -- for instance, beam search relies on many expensive reshapes, which are not necessary.

(`dynamic_update_slice` is simply syntactic sugar; it can be replaced by a more verbose scatter operation.)
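
As a minimal, self-contained sketch of that equivalence (toy 1-D shapes, for illustration only; not code from this PR):

import tensorflow as tf
from tensorflow.compiler.tf2xla.python.xla import dynamic_update_slice

operand = tf.zeros((8,), dtype=tf.int32)  # tensor being updated
update = tf.constant([7, 7])              # values to write
start = tf.constant([3])                  # write offset along dim 0

# XLA-friendly in-place write ...
via_dus = dynamic_update_slice(operand, update, start)

# ... and the equivalent, more verbose scatter
indices = tf.reshape(tf.range(3, 3 + tf.size(update)), (-1, 1))
via_scatter = tf.tensor_scatter_nd_update(operand, indices, update)

# both yield [0, 0, 0, 7, 7, 0, 0, 0]
tf.debugging.assert_equal(via_dus, via_scatter)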

new_past += (tuple(new_past_layer),)

# One 1 for the decoder_start_token_id, zeros for the currently-unfilled locations in the past tensor,
# and ones for the actual input_ids
decoder_attention_mask = tf.concat(
[
tf.ones((batch_size, 1), dtype=tf.int32),
tf.zeros((batch_size, num_padding_values), dtype=tf.int32),
tf.ones((batch_size, 1), dtype=tf.int32),
],
axis=1,
)
else:
slice_start_base = tf.constant([0, 0, 1, 0])
decoder_attention_mask_update_slice = tf.ones((batch_size, 1), dtype=decoder_attention_mask.dtype)
# index in the padded `past` (and attention mask) where the newest slice is written
new_past_index = current_pos - 1

new_past = ()
for past_layer in past:
new_past_layer = list(past_layer)
for i in range(len(new_past_layer[:2])):
update_slice = past_layer[i][:, :, -1:]
# Write the last slice to the first open location in the padded past array
# and then truncate the last slice off the array
new_past_layer[i] = dynamic_update_slice(
past_layer[i][:, :, :-1], update_slice, slice_start_base * new_past_index
)
new_past += (tuple(new_past_layer),)

update_start = tf.constant([0, 1], dtype=tf.int32) * new_past_index
decoder_attention_mask = dynamic_update_slice(
decoder_attention_mask, decoder_attention_mask_update_slice, update_start
)

# set `decoder_attention_mask` and `past`
model_kwargs["decoder_attention_mask"] = decoder_attention_mask
model_kwargs["past"] = new_past

return model_kwargs
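
To make the padding and in-place update pattern above concrete, here is a toy sketch with made-up shapes (batch=2, heads=4, prompt length=3, head dim=8, max_length=10); it mirrors the two branches of the deleted function in simplified form and is not part of the diff:

import tensorflow as tf
from tensorflow.compiler.tf2xla.python.xla import dynamic_update_slice

batch, heads, prompt_len, head_dim, max_length = 2, 4, 3, 8, 10
past_layer = tf.random.normal((batch, heads, prompt_len, head_dim))  # one key or value tensor

# First call: right-pad the time dimension (dim 2) so the tensor already has `max_length - 1` slots
num_padding_values = max_length - prompt_len - 1
padding_values = tf.scatter_nd(indices=[[2, 1]], updates=[num_padding_values], shape=(4, 2))
padded_past = tf.pad(past_layer, padding_values)  # shape (2, 4, 9, 8), zero-padded on the right

# Subsequent calls: write the newest slice into the next open slot without changing the shape
new_slice = tf.random.normal((batch, heads, 1, head_dim))
slice_start_base = tf.constant([0, 0, 1, 0])
new_past_index = prompt_len  # first open slot in this toy example
updated_past = dynamic_update_slice(padded_past, new_slice, slice_start_base * new_past_index)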

def prepare_decoder_input_ids_from_labels(self, labels: tf.Tensor):
return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)

4 changes: 4 additions & 0 deletions src/transformers/models/ctrl/modeling_tf_ctrl.py
@@ -571,6 +571,8 @@ class TFCTRLLMHead(tf.keras.layers.Layer):
def __init__(self, config, input_embeddings, **kwargs):
super().__init__(**kwargs)
self.vocab_size = config.vocab_size
# CTRL has numerical issues in XLA generate
self.supports_xla_generation = False

# The output weights are the same as the input embeddings, but there is
# an output-only bias for each token.
@@ -613,6 +615,8 @@ def __init__(self, config, *inputs, **kwargs):
self.transformer = TFCTRLMainLayer(config, name="transformer")

self.lm_head = TFCTRLLMHead(config, self.transformer.w, name="lm_head")
# CTRL has numerical issues in XLA generate
self.supports_xla_generation = False

def get_lm_head(self):
return self.lm_head
2 changes: 2 additions & 0 deletions src/transformers/models/flaubert/modeling_tf_flaubert.py
@@ -761,6 +761,8 @@ def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
self.transformer = TFFlaubertMainLayer(config, name="transformer")
self.pred_layer = TFFlaubertPredLayer(config, self.transformer.embeddings, name="pred_layer_._proj")
# Flaubert does not have past caching features
self.supports_xla_generation = False

def get_lm_head(self):
return self.pred_layer
58 changes: 0 additions & 58 deletions src/transformers/models/gpt2/modeling_tf_gpt2.py
@@ -20,7 +20,6 @@

import numpy as np
import tensorflow as tf
from tensorflow.compiler.tf2xla.python.xla import dynamic_update_slice

from ...activations_tf import get_tf_activation
from ...modeling_tf_outputs import (
@@ -838,63 +837,6 @@ def prepare_inputs_for_generation(self, inputs, past=None, use_cache=None, **kwa
"token_type_ids": token_type_ids,
}

def _update_model_kwargs_for_xla_generation(self, outputs, model_kwargs, current_pos, max_length):
# TODO(Pvp, Joao, Matt) - this function can be cleaned a bit and refactored
# quite some duplicated code patterns it seems
# also the `attention_mask` is currently used in a somewhat hacky way to
# correctly influence the `past_key_values` - not sure if this is the way to go
# Let's keep that for a future PR.
past = outputs.past_key_values
is_past_initialized = model_kwargs.pop("past", None) is not None
attention_mask = model_kwargs.pop("attention_mask")
batch_size = attention_mask.shape[0]

if not is_past_initialized:
# past[0].shape[3] is seq_length of prompt
num_padding_values = max_length - past[0].shape[3] - 1

padding_values = np.zeros((5, 2), dtype=np.int32)
padding_values[3, 1] = num_padding_values
padding_values = tf.constant(padding_values)

new_past = list(past)
for i in range(len(past)):
new_past[i] = tf.pad(past[i], padding_values)

# Zeros for the currently-unfilled locations in the past tensor, ones for the actual input_ids
attention_mask = tf.concat(
[
attention_mask,
tf.zeros((batch_size, num_padding_values), dtype=attention_mask.dtype),
tf.ones((batch_size, 1), dtype=attention_mask.dtype),
],
axis=1,
)
else:
new_past = [None for _ in range(len(past))]
slice_start_base = tf.constant([0, 0, 0, 1, 0])
attention_mask_update_slice = tf.ones((batch_size, 1), dtype=attention_mask.dtype)
# -1 because current_pos has already been incremented before this function
# -1 again because last index = len - 1
new_past_index = current_pos - 2

for i in range(len(past)):
update_slice = past[i][:, :, :, -1:]
# Write the last slice to the first open location in the padded past array
# and then truncate the last slice off the array
new_past[i] = dynamic_update_slice(
past[i][:, :, :, :-1], update_slice, slice_start_base * new_past_index
)

update_start = tf.constant([0, 1], dtype=tf.int32) * new_past_index
attention_mask = dynamic_update_slice(attention_mask, attention_mask_update_slice, update_start)

# set `attention_mask` and `past`
model_kwargs["attention_mask"] = attention_mask
model_kwargs["past"] = tuple(new_past)

return model_kwargs

@unpack_inputs
@add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
2 changes: 2 additions & 0 deletions src/transformers/models/gptj/modeling_tf_gptj.py
@@ -722,6 +722,8 @@ def __init__(self, config, *inputs, **kwargs):
self.lm_head = tf.keras.layers.Dense(
config.vocab_size, kernel_initializer=get_initializer(config.initializer_range), name="lm_head"
)
# TODO (Joao): investigate why GPTJ has numerical issues in XLA generate
self.supports_xla_generation = False
Contributor:

Let's maybe open a "Good second issue" after this PR for GPT-J


def get_output_embeddings(self):
return self.lm_head
2 changes: 2 additions & 0 deletions src/transformers/models/led/modeling_tf_led.py
@@ -2334,6 +2334,8 @@ def __init__(self, config, *inputs, **kwargs):
self.final_logits_bias = self.add_weight(
name="final_logits_bias", shape=[1, config.vocab_size], initializer="zeros", trainable=False
)
# TODO (Joao): investigate why LED has numerical issues in XLA generate
self.supports_xla_generation = False

def get_decoder(self):
return self.led.decoder
2 changes: 2 additions & 0 deletions src/transformers/models/openai/modeling_tf_openai.py
@@ -556,6 +556,8 @@ class TFOpenAIGPTLMHeadModel(TFOpenAIGPTPreTrainedModel, TFCausalLanguageModelin
def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
self.transformer = TFOpenAIGPTMainLayer(config, name="transformer")
# OpenAIGPT does not have past caching features
self.supports_xla_generation = False

def get_output_embeddings(self):
return self.get_input_embeddings()
src/transformers/models/speech_to_text/modeling_tf_speech_to_text.py
@@ -1332,6 +1332,8 @@ def __init__(self, config: Speech2TextConfig):
super().__init__(config)
self.model = TFSpeech2TextMainLayer(config, name="model")
self.lm_head = tf.keras.layers.Dense(self.config.vocab_size, use_bias=False, name="lm_head")
# TODO (Joao): investigate why Speech2Text has numerical issues in XLA generate
self.supports_xla_generation = False

def get_encoder(self):
return self.model.encoder
60 changes: 0 additions & 60 deletions src/transformers/models/t5/modeling_tf_t5.py
@@ -23,7 +23,6 @@

import numpy as np
import tensorflow as tf
from tensorflow.compiler.tf2xla.python.xla import dynamic_update_slice

from ...activations_tf import get_tf_activation
from ...modeling_tf_outputs import (
@@ -1501,65 +1500,6 @@ def prepare_inputs_for_generation(
"use_cache": use_cache,
}

def _update_model_kwargs_for_xla_generation(self, outputs, model_kwargs, current_pos, max_length):
# TODO(Pvp, Joao, Matt) - this function can be cleaned a bit and refactored
# quite some duplicated code patterns it seems
past = outputs.past_key_values
is_past_initialized = model_kwargs.pop("past", None) is not None
decoder_attention_mask = model_kwargs.pop("decoder_attention_mask", None)
batch_size = past[0][0].shape[0]

if not is_past_initialized:
# past[0][0].shape[2] is seq_length of prompt
num_padding_values = max_length - past[0][0].shape[2] - 1
padding_values = tf.scatter_nd(indices=[[2, 1]], updates=[num_padding_values], shape=(4, 2))

new_past = ()
for past_layer in past:
new_past_layer = list(past_layer)
for i in range(len(new_past_layer[:2])):
new_past_layer[i] = tf.pad(past_layer[i], padding_values)
new_past += (tuple(new_past_layer),)

# One 1 for the decoder_start_token_id, zeros for the currently-unfilled locations in the past tensor,
# and ones for the actual input_ids
decoder_attention_mask = tf.concat(
[
tf.ones((batch_size, 1), dtype=tf.int32),
tf.zeros((batch_size, num_padding_values), dtype=tf.int32),
tf.ones((batch_size, 1), dtype=tf.int32),
],
axis=1,
)
else:
slice_start_base = tf.constant([0, 0, 1, 0])
decoder_attention_mask_update_slice = tf.ones((batch_size, 1), dtype=decoder_attention_mask.dtype)
# index in the padded `past` (and attention mask) where the newest slice is written
new_past_index = current_pos - 1

new_past = ()
for past_layer in past:
new_past_layer = list(past_layer)
for i in range(len(new_past_layer[:2])):
update_slice = past_layer[i][:, :, -1:]
# Write the last slice to the first open location in the padded past array
# and then truncate the last slice off the array
new_past_layer[i] = dynamic_update_slice(
past_layer[i][:, :, :-1], update_slice, slice_start_base * new_past_index
)
new_past += (tuple(new_past_layer),)

update_start = tf.constant([0, 1], dtype=tf.int32) * new_past_index
decoder_attention_mask = dynamic_update_slice(
decoder_attention_mask, decoder_attention_mask_update_slice, update_start
)

# set `decoder_attention_mask` and `past`
model_kwargs["decoder_attention_mask"] = decoder_attention_mask
model_kwargs["past"] = new_past

return model_kwargs

def prepare_decoder_input_ids_from_labels(self, labels: tf.Tensor):
return self._shift_right(labels)

2 changes: 2 additions & 0 deletions src/transformers/models/xlm/modeling_tf_xlm.py
@@ -797,6 +797,8 @@ def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
self.transformer = TFXLMMainLayer(config, name="transformer")
self.pred_layer = TFXLMPredLayer(config, self.transformer.embeddings, name="pred_layer_._proj")
# XLM does not have past caching features
self.supports_xla_generation = False

def get_lm_head(self):
return self.pred_layer
2 changes: 2 additions & 0 deletions src/transformers/models/xlnet/modeling_tf_xlnet.py
@@ -1192,6 +1192,8 @@ def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
self.transformer = TFXLNetMainLayer(config, name="transformer")
self.lm_loss = TFXLNetLMHead(config, self.transformer.word_embedding, name="lm_loss")
# generate fails to convert to a graph with XLNet
self.supports_xla_generation = False

def get_lm_head(self):
return self.lm_loss
34 changes: 6 additions & 28 deletions tests/models/bart/test_modeling_tf_bart.py
@@ -152,23 +152,6 @@ def check_decoder_model_past_large_inputs(self, config, inputs_dict):
# test that outputs are equal for slice
tf.debugging.assert_near(output_from_past_slice, output_from_no_past_slice, rtol=1e-3)

def create_and_check_bart_xla_generate_fast(self, config, input_ids, *args):
config.eos_token_id = None # Generate until max length
config.max_length = 10
config.do_sample = False
config.num_beams = 1
model = TFBartForConditionalGeneration(config=config)

# make sure there are no pad tokens in prompt
input_ids = tf.where(input_ids != config.pad_token_id, input_ids, config.pad_token_id - 1)

generated = model.generate(input_ids)

generate_xla = tf.function(model.generate, jit_compile=True)
generated_xla = generate_xla(input_ids)

self.parent.assertListEqual(generated.numpy().tolist(), generated_xla.numpy().tolist())


def prepare_bart_inputs_dict(
config,
@@ -310,10 +293,6 @@ def _get_word_embedding_weight(model, embedding_layer):
models_equal = False
self.assertTrue(models_equal)

def test_bart_model_xla_generate_fast(self):
config, inputs = self.model_tester.prepare_config_and_inputs_for_common()
self.model_tester.create_and_check_bart_xla_generate_fast(config, inputs["input_ids"])

def test_saved_model_creation(self):
# This test is too long (>30sec) and makes fail the CI
pass
@@ -703,10 +682,8 @@ def test_xsum_1_1_generation(self):
result = self.tok.batch_decode(generated_ids, skip_special_tokens=True)[0]
assert result == EXPECTED

def test_xsum_1_1_xla_greedy_generation(self):
# TODO (Joao): this is a temporary test, while XLA beam search is not operational. Move the XLA==non-XLA
# comparisons to the other tests after enabling XLA beam search.
# Note -- `no_repeat_ngram_size` has to be disabled, since it is not compatible with XLA
def test_xsum_1_1_xla_generation(self):
# same test as above, but with `no_repeat_ngram_size=0` (not compatible with XLA) and XLA comparison enabled
model = self.xsum_1_1_model
assert model.model.decoder.embed_tokens._layer == model.model.shared
ARTICLE = (
@@ -748,15 +725,16 @@ def test_xsum_1_1_xla_greedy_generation(self):
)
EXPECTED = (
" The International Criminal Court (ICC) has announced that it is to be investigated by the International"
" Criminal Court (ICC) over claims that the Palestinian genocide."
" Criminal Court (ICC) over allegations of war crimes."
)

dct = self.tok(ARTICLE, return_tensors="tf")
generated_ids = model.generate(**dct, num_beams=1, no_repeat_ngram_size=0)
generated_ids = model.generate(**dct, num_beams=4, no_repeat_ngram_size=0)
result = self.tok.batch_decode(generated_ids, skip_special_tokens=True)[0]
assert result == EXPECTED

xla_generate = tf.function(model.generate, jit_compile=True)
generated_ids = xla_generate(**dct, num_beams=1, no_repeat_ngram_size=0)
generated_ids = xla_generate(**dct, num_beams=4, no_repeat_ngram_size=0)
result = self.tok.batch_decode(generated_ids, skip_special_tokens=True)[0]
assert result == EXPECTED
