TF: XLA beam search + most generation-compatible models are now also XLA-generate-compatible (huggingface#17857)

* working beam search 🎉

* XLA generation compatible with ALL classes

* add xla generation slow test
gante authored and viclzhu committed Jul 18, 2022
1 parent e4e0f8b commit 036baff
Showing 16 changed files with 356 additions and 301 deletions.
249 changes: 198 additions & 51 deletions src/transformers/generation_tf_utils.py

Large diffs are not rendered by default.
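
Since that core diff is not rendered, here is a minimal sketch of how the new XLA generation path is typically exercised, mirroring the pattern used in the updated BART tests further down. The checkpoint name and the padding length are illustrative assumptions, not part of this commit.

import tensorflow as tf
from transformers import AutoTokenizer, TFAutoModelForSeq2SeqLM

# Illustrative checkpoint -- any TF model that does not set `supports_xla_generation = False`
# should work the same way after this commit.
tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-cnn")
model = TFAutoModelForSeq2SeqLM.from_pretrained("facebook/bart-large-cnn")

# Compile generate() once with XLA, exactly as the slow test below does.
xla_generate = tf.function(model.generate, jit_compile=True)

# Padding prompts to a fixed length keeps input shapes constant, avoiding a retrace/recompile per prompt.
inputs = tokenizer(
    "The quick brown fox jumps over the lazy dog.",
    return_tensors="tf",
    padding="max_length",
    max_length=64,
)
generated_ids = xla_generate(**inputs, num_beams=4)
print(tokenizer.batch_decode(generated_ids, skip_special_tokens=True))

Beam search (`num_beams > 1`) inside the compiled function is the new capability here; the previous XLA test below was limited to greedy search (`num_beams=1`).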

64 changes: 0 additions & 64 deletions src/transformers/models/bart/modeling_tf_bart.py
@@ -20,7 +20,6 @@

import numpy as np
import tensorflow as tf
from tensorflow.compiler.tf2xla.python.xla import dynamic_update_slice

from ...activations_tf import get_tf_activation
from ...modeling_tf_outputs import (
@@ -1434,69 +1433,6 @@ def prepare_inputs_for_generation(
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
}

def _update_model_kwargs_for_xla_generation(self, outputs, model_kwargs, current_pos, max_length):
# TODO(Pvp, Joao, Matt) - this function can be cleaned a bit and refactored
# quite some duplicated code patterns it seems
past = outputs.past_key_values
is_past_initialized = model_kwargs.pop("past", None) is not None
decoder_attention_mask = model_kwargs.pop("decoder_attention_mask", None)
batch_size = past[0][0].shape[0]

if not is_past_initialized:
# past[0][0].shape[2] is seq_length of prompt
# The padded version of `past` requires only `max_length - 1` steps along the time dimension.
num_padding_values = max_length - past[0][0].shape[2] - 1
# prepare the padding tensor for `tf.pad`.
# `shape=(4, 2)` because each tensor element in `past` has `rank=4`.
# `indices=[[2, 1]]` means the time dimension (dim 2) needs **right**-padding (`1` means padding afterward).
padding_values = tf.scatter_nd(indices=[[2, 1]], updates=[num_padding_values], shape=(4, 2))

new_past = ()
for past_layer in past:
new_past_layer = list(past_layer)
for i in range(len(new_past_layer[:2])):
new_past_layer[i] = tf.pad(past_layer[i], padding_values)
new_past += (tuple(new_past_layer),)

# one 1 for the decoder_start_token_id, zeros for the currently-unfilled locations in the past tensor,
# and ones for the actual input_ids
decoder_attention_mask = tf.concat(
[
tf.ones((batch_size, 1), dtype=tf.int32),
tf.zeros((batch_size, num_padding_values), dtype=tf.int32),
tf.ones((batch_size, 1), dtype=tf.int32),
],
axis=1,
)
else:
slice_start_base = tf.constant([0, 0, 1, 0])
decoder_attention_mask_update_slice = tf.ones((batch_size, 1), dtype=decoder_attention_mask.dtype)
# index of the `past` slot that gets written at this step
new_past_index = current_pos - 1

new_past = ()
for past_layer in past:
new_past_layer = list(past_layer)
for i in range(len(new_past_layer[:2])):
update_slice = past_layer[i][:, :, -1:]
# Write the last slice to the first open location in the padded past array
# and then truncate the last slice off the array
new_past_layer[i] = dynamic_update_slice(
past_layer[i][:, :, :-1], update_slice, slice_start_base * new_past_index
)
new_past += (tuple(new_past_layer),)

update_start = tf.constant([0, 1], dtype=tf.int32) * new_past_index
decoder_attention_mask = dynamic_update_slice(
decoder_attention_mask, decoder_attention_mask_update_slice, update_start
)

# set `decoder_attention_mask` and `past`
model_kwargs["decoder_attention_mask"] = decoder_attention_mask
model_kwargs["past"] = new_past

return model_kwargs

def prepare_decoder_input_ids_from_labels(self, labels: tf.Tensor):
return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)

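
The method removed above is the per-model XLA cache bookkeeping that this commit centralizes in generation_tf_utils.py. Purely as a hedged, standalone sketch of the underlying trick (pre-padding the cached `past` time dimension to a fixed size, then writing each new step in place with `dynamic_update_slice`), using made-up toy shapes:

import tensorflow as tf
from tensorflow.compiler.tf2xla.python.xla import dynamic_update_slice

# Toy shapes for illustration only: batch=2, heads=4, prompt_len=3, head_dim=8
past_key = tf.random.normal((2, 4, 3, 8))
max_length = 7

# Right-pad the time dimension (dim 2) so the cache keeps a fixed shape under XLA.
# `shape=(4, 2)` because the tensor has rank 4; `indices=[[2, 1]]` means pad *after* dim 2.
num_padding_values = max_length - past_key.shape[2] - 1                        # 7 - 3 - 1 = 3
padding_values = tf.scatter_nd(indices=[[2, 1]], updates=[num_padding_values], shape=(4, 2))
padded_key = tf.pad(past_key, padding_values)                                  # (2, 4, 6, 8)

# Each decoding step then writes the newest slice into the next open slot instead of
# concatenating, which would change the tensor's shape and trigger recompilation.
new_step = tf.random.normal((2, 4, 1, 8))
slice_start_base = tf.constant([0, 0, 1, 0])
new_past_index = 3                               # first open slot after a prompt of length 3
padded_key = dynamic_update_slice(padded_key, new_step, slice_start_base * new_past_index)

The real implementation also keeps a matching `decoder_attention_mask` with zeros over the still-empty slots, exactly as in the removed code above.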
4 changes: 4 additions & 0 deletions src/transformers/models/ctrl/modeling_tf_ctrl.py
@@ -571,6 +571,8 @@ class TFCTRLLMHead(tf.keras.layers.Layer):
def __init__(self, config, input_embeddings, **kwargs):
super().__init__(**kwargs)
self.vocab_size = config.vocab_size
# CTRL has numerical issues in XLA generate
self.supports_xla_generation = False

# The output weights are the same as the input embeddings, but there is
# an output-only bias for each token.
@@ -613,6 +615,8 @@ def __init__(self, config, *inputs, **kwargs):
self.transformer = TFCTRLMainLayer(config, name="transformer")

self.lm_head = TFCTRLLMHead(config, self.transformer.w, name="lm_head")
# CTRL has numerical issues in XLA generate
self.supports_xla_generation = False

def get_lm_head(self):
return self.lm_head
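
Several models in this commit opt out of compiled generation via `supports_xla_generation = False`. The actual guard lives in the unrendered generation_tf_utils.py diff; the helper below is only an assumption-laden sketch of how such a flag could be consulted, not the library's real code.

import tensorflow as tf

def check_xla_support(model):
    # Hedged sketch: the attribute name comes from this diff, but this helper, the
    # graph-mode detection, and the error message are assumptions.
    inside_compiled_fn = not tf.executing_eagerly()
    if inside_compiled_fn and not getattr(model, "supports_xla_generation", True):
        raise ValueError(
            f"{model.__class__.__name__} sets `supports_xla_generation=False`; "
            "call `generate()` eagerly rather than inside `tf.function(jit_compile=True)`."
        )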
2 changes: 2 additions & 0 deletions src/transformers/models/flaubert/modeling_tf_flaubert.py
@@ -761,6 +761,8 @@ def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
self.transformer = TFFlaubertMainLayer(config, name="transformer")
self.pred_layer = TFFlaubertPredLayer(config, self.transformer.embeddings, name="pred_layer_._proj")
# Flaubert does not have past caching features
self.supports_xla_generation = False

def get_lm_head(self):
return self.pred_layer
58 changes: 0 additions & 58 deletions src/transformers/models/gpt2/modeling_tf_gpt2.py
@@ -20,7 +20,6 @@

import numpy as np
import tensorflow as tf
from tensorflow.compiler.tf2xla.python.xla import dynamic_update_slice

from ...activations_tf import get_tf_activation
from ...modeling_tf_outputs import (
@@ -838,63 +837,6 @@ def prepare_inputs_for_generation(self, inputs, past=None, use_cache=None, **kwa
"token_type_ids": token_type_ids,
}

def _update_model_kwargs_for_xla_generation(self, outputs, model_kwargs, current_pos, max_length):
# TODO(Pvp, Joao, Matt) - this function can be cleaned a bit and refactored
# quite some duplicated code patterns it seems
# also the `attention_mask` is currently used in a somewhat hacky way to
# correctly influence the `past_key_values` - not sure if this is the way to go
# Let's keep that for a future PR.
past = outputs.past_key_values
is_past_initialized = model_kwargs.pop("past", None) is not None
attention_mask = model_kwargs.pop("attention_mask")
batch_size = attention_mask.shape[0]

if not is_past_initialized:
# past[0].shape[3] is seq_length of prompt
num_padding_values = max_length - past[0].shape[3] - 1

padding_values = np.zeros((5, 2), dtype=np.int32)
padding_values[3, 1] = num_padding_values
padding_values = tf.constant(padding_values)

new_past = list(past)
for i in range(len(past)):
new_past[i] = tf.pad(past[i], padding_values)

# Zeros for the currently-unfilled locations in the past tensor, ones for the actual input_ids
attention_mask = tf.concat(
[
attention_mask,
tf.zeros((batch_size, num_padding_values), dtype=attention_mask.dtype),
tf.ones((batch_size, 1), dtype=attention_mask.dtype),
],
axis=1,
)
else:
new_past = [None for _ in range(len(past))]
slice_start_base = tf.constant([0, 0, 0, 1, 0])
attention_mask_update_slice = tf.ones((batch_size, 1), dtype=attention_mask.dtype)
# -1 because current_pos has already been incremented before this function
# -1 again because last index = len - 1
new_past_index = current_pos - 2

for i in range(len(past)):
update_slice = past[i][:, :, :, -1:]
# Write the last slice to the first open location in the padded past array
# and then truncate the last slice off the array
new_past[i] = dynamic_update_slice(
past[i][:, :, :, :-1], update_slice, slice_start_base * new_past_index
)

update_start = tf.constant([0, 1], dtype=tf.int32) * new_past_index
attention_mask = dynamic_update_slice(attention_mask, attention_mask_update_slice, update_start)

# set `attention_mask` and `past`
model_kwargs["attention_mask"] = attention_mask
model_kwargs["past"] = tuple(new_past)

return model_kwargs

@unpack_inputs
@add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
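
To make the index arithmetic in the removed GPT-2 method concrete, a small worked example with illustrative numbers:

# Worked example for the decoder-only (GPT-2) bookkeeping above; all numbers are illustrative.
max_length = 8                                    # target total sequence length
prompt_len = 3                                    # past[0].shape[3] after the first forward pass

# First call: right-pad the cached `past` so its time dimension stays fixed under XLA.
num_padding_values = max_length - prompt_len - 1          # 8 - 3 - 1 = 4 empty slots
padded_time_dim = prompt_len + num_padding_values         # 7 == max_length - 1

# Later calls: `current_pos` was already incremented before this method runs, and tensor
# indices are 0-based, hence the two -1s (i.e. the -2) in the original code.
current_pos = 5                                   # illustrative value at some decoding step
new_past_index = current_pos - 2                  # 3: start index passed to dynamic_update_slice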
2 changes: 2 additions & 0 deletions src/transformers/models/gptj/modeling_tf_gptj.py
@@ -722,6 +722,8 @@ def __init__(self, config, *inputs, **kwargs):
self.lm_head = tf.keras.layers.Dense(
config.vocab_size, kernel_initializer=get_initializer(config.initializer_range), name="lm_head"
)
# TODO (Joao): investigate why GPTJ has numerical issues in XLA generate
self.supports_xla_generation = False

def get_output_embeddings(self):
return self.lm_head
2 changes: 2 additions & 0 deletions src/transformers/models/led/modeling_tf_led.py
@@ -2334,6 +2334,8 @@ def __init__(self, config, *inputs, **kwargs):
self.final_logits_bias = self.add_weight(
name="final_logits_bias", shape=[1, config.vocab_size], initializer="zeros", trainable=False
)
# TODO (Joao): investigate why LED has numerical issues in XLA generate
self.supports_xla_generation = False

def get_decoder(self):
return self.led.decoder
2 changes: 2 additions & 0 deletions src/transformers/models/openai/modeling_tf_openai.py
@@ -556,6 +556,8 @@ class TFOpenAIGPTLMHeadModel(TFOpenAIGPTPreTrainedModel, TFCausalLanguageModelin
def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
self.transformer = TFOpenAIGPTMainLayer(config, name="transformer")
# OpenAIGPT does not have past caching features
self.supports_xla_generation = False

def get_output_embeddings(self):
return self.get_input_embeddings()
src/transformers/models/speech_to_text/modeling_tf_speech_to_text.py
@@ -1332,6 +1332,8 @@ def __init__(self, config: Speech2TextConfig):
super().__init__(config)
self.model = TFSpeech2TextMainLayer(config, name="model")
self.lm_head = tf.keras.layers.Dense(self.config.vocab_size, use_bias=False, name="lm_head")
# TODO (Joao): investigate why Speech2Text has numerical issues in XLA generate
self.supports_xla_generation = False

def get_encoder(self):
return self.model.encoder
60 changes: 0 additions & 60 deletions src/transformers/models/t5/modeling_tf_t5.py
@@ -23,7 +23,6 @@

import numpy as np
import tensorflow as tf
from tensorflow.compiler.tf2xla.python.xla import dynamic_update_slice

from ...activations_tf import get_tf_activation
from ...modeling_tf_outputs import (
@@ -1501,65 +1500,6 @@ def prepare_inputs_for_generation(
"use_cache": use_cache,
}

def _update_model_kwargs_for_xla_generation(self, outputs, model_kwargs, current_pos, max_length):
# TODO(Pvp, Joao, Matt) - this function can be cleaned a bit and refactored
# quite some duplicated code patterns it seems
past = outputs.past_key_values
is_past_initialized = model_kwargs.pop("past", None) is not None
decoder_attention_mask = model_kwargs.pop("decoder_attention_mask", None)
batch_size = past[0][0].shape[0]

if not is_past_initialized:
# past[0][0].shape[2] is seq_length of prompt
num_padding_values = max_length - past[0][0].shape[2] - 1
padding_values = tf.scatter_nd(indices=[[2, 1]], updates=[num_padding_values], shape=(4, 2))

new_past = ()
for past_layer in past:
new_past_layer = list(past_layer)
for i in range(len(new_past_layer[:2])):
new_past_layer[i] = tf.pad(past_layer[i], padding_values)
new_past += (tuple(new_past_layer),)

# one 1 for the decoder_start_token_id, zeros for the currently-unfilled locations in the past tensor,
# and ones for the actual input_ids
decoder_attention_mask = tf.concat(
[
tf.ones((batch_size, 1), dtype=tf.int32),
tf.zeros((batch_size, num_padding_values), dtype=tf.int32),
tf.ones((batch_size, 1), dtype=tf.int32),
],
axis=1,
)
else:
slice_start_base = tf.constant([0, 0, 1, 0])
decoder_attention_mask_update_slice = tf.ones((batch_size, 1), dtype=decoder_attention_mask.dtype)
# index of the `past` slot that gets written at this step
new_past_index = current_pos - 1

new_past = ()
for past_layer in past:
new_past_layer = list(past_layer)
for i in range(len(new_past_layer[:2])):
update_slice = past_layer[i][:, :, -1:]
# Write the last slice to the first open location in the padded past array
# and then truncate the last slice off the array
new_past_layer[i] = dynamic_update_slice(
past_layer[i][:, :, :-1], update_slice, slice_start_base * new_past_index
)
new_past += (tuple(new_past_layer),)

update_start = tf.constant([0, 1], dtype=tf.int32) * new_past_index
decoder_attention_mask = dynamic_update_slice(
decoder_attention_mask, decoder_attention_mask_update_slice, update_start
)

# set `decoder_attention_mask` and `past`
model_kwargs["decoder_attention_mask"] = decoder_attention_mask
model_kwargs["past"] = new_past

return model_kwargs

def prepare_decoder_input_ids_from_labels(self, labels: tf.Tensor):
return self._shift_right(labels)

2 changes: 2 additions & 0 deletions src/transformers/models/xlm/modeling_tf_xlm.py
@@ -797,6 +797,8 @@ def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
self.transformer = TFXLMMainLayer(config, name="transformer")
self.pred_layer = TFXLMPredLayer(config, self.transformer.embeddings, name="pred_layer_._proj")
# XLM does not have past caching features
self.supports_xla_generation = False

def get_lm_head(self):
return self.pred_layer
2 changes: 2 additions & 0 deletions src/transformers/models/xlnet/modeling_tf_xlnet.py
@@ -1192,6 +1192,8 @@ def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
self.transformer = TFXLNetMainLayer(config, name="transformer")
self.lm_loss = TFXLNetLMHead(config, self.transformer.word_embedding, name="lm_loss")
# generate fails to convert to a graph with XLNet
self.supports_xla_generation = False

def get_lm_head(self):
return self.lm_loss
34 changes: 6 additions & 28 deletions tests/models/bart/test_modeling_tf_bart.py
@@ -152,23 +152,6 @@ def check_decoder_model_past_large_inputs(self, config, inputs_dict):
# test that outputs are equal for slice
tf.debugging.assert_near(output_from_past_slice, output_from_no_past_slice, rtol=1e-3)

def create_and_check_bart_xla_generate_fast(self, config, input_ids, *args):
config.eos_token_id = None # Generate until max length
config.max_length = 10
config.do_sample = False
config.num_beams = 1
model = TFBartForConditionalGeneration(config=config)

# make sure there are no pad tokens in prompt
input_ids = tf.where(input_ids != config.pad_token_id, input_ids, config.pad_token_id - 1)

generated = model.generate(input_ids)

generate_xla = tf.function(model.generate, jit_compile=True)
generated_xla = generate_xla(input_ids)

self.parent.assertListEqual(generated.numpy().tolist(), generated_xla.numpy().tolist())


def prepare_bart_inputs_dict(
config,
@@ -310,10 +293,6 @@ def _get_word_embedding_weight(model, embedding_layer):
models_equal = False
self.assertTrue(models_equal)

def test_bart_model_xla_generate_fast(self):
config, inputs = self.model_tester.prepare_config_and_inputs_for_common()
self.model_tester.create_and_check_bart_xla_generate_fast(config, inputs["input_ids"])

def test_saved_model_creation(self):
# This test is too long (>30sec) and makes fail the CI
pass
@@ -703,10 +682,8 @@ def test_xsum_1_1_generation(self):
result = self.tok.batch_decode(generated_ids, skip_special_tokens=True)[0]
assert result == EXPECTED

def test_xsum_1_1_xla_greedy_generation(self):
# TODO (Joao): this is temporary test, while XLA beam search is not operational. Move the XLA==non-XLA
# comparisons to the other tests after enabling XLA beam search.
# Note -- `no_repeat_ngram_size` has to be disabled, since it is not compatible with XLA
def test_xsum_1_1_xla_generation(self):
# same test as above, but with `no_repeat_ngram_size=0` (not compatible with XLA) and XLA comparison enabled
model = self.xsum_1_1_model
assert model.model.decoder.embed_tokens._layer == model.model.shared
ARTICLE = (
@@ -748,15 +725,16 @@ def test_xsum_1_1_xla_greedy_generation(self):
)
EXPECTED = (
" The International Criminal Court (ICC) has announced that it is to be investigated by the International"
" Criminal Court (ICC) over claims that the Palestinian genocide."
" Criminal Court (ICC) over allegations of war crimes."
)

dct = self.tok(ARTICLE, return_tensors="tf")
generated_ids = model.generate(**dct, num_beams=1, no_repeat_ngram_size=0)
generated_ids = model.generate(**dct, num_beams=4, no_repeat_ngram_size=0)
result = self.tok.batch_decode(generated_ids, skip_special_tokens=True)[0]
assert result == EXPECTED

xla_generate = tf.function(model.generate, jit_compile=True)
generated_ids = xla_generate(**dct, num_beams=1, no_repeat_ngram_size=0)
generated_ids = xla_generate(**dct, num_beams=4, no_repeat_ngram_size=0)
result = self.tok.batch_decode(generated_ids, skip_special_tokens=True)[0]
assert result == EXPECTED

