TF: XLA beam search + most generation-compatible models are now also XLA-generate-compatible #17857
Conversation
```python
# ignore scalars (e.g. cache index)
if tf.rank(tensor) == 0:
    return tensor
```
This pattern was inherited from FLAX. Unlike in FLAX, it seems, this `if` throws the XLA compiler off balance, causing `tensor` to have an unknown shape. All the problems I was seeing before were a downstream consequence of that unknown shape.
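The pattern the guard implements ("pass scalars through, process everything else") can be illustrated outside TF. Here is a hypothetical numpy analogue — the function name and the padding scheme are made up for illustration, not taken from the PR:

```python
import numpy as np

def pad_cache_entry(entry, extra_length):
    """Hypothetical numpy analogue of the guard above: scalars (rank 0),
    such as the cache index, pass through untouched; ranked tensors are
    padded with zeros along their first axis."""
    entry = np.asarray(entry)
    if entry.ndim == 0:  # the case the `if tf.rank(tensor) == 0` guard handles
        return entry
    pad_widths = [(0, extra_length)] + [(0, 0)] * (entry.ndim - 1)
    return np.pad(entry, pad_widths)

scalar = pad_cache_entry(np.int32(3), 2)      # scalar passes through unchanged
padded = pad_cache_entry(np.ones((2, 3)), 2)  # ranked tensor grows to (4, 3)
```

In eager numpy this branch is harmless; inside an XLA-traced function, the equivalent `tf.rank`-based branch is what left `tensor` with an unknown shape.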
Very cool! Can we try it out on at least one Encoder-Decoder architecture as well (just to check that this code holds true there)?
@patrickvonplaten @Rocketknight1 now with encoder-decoder tests, and ready for review -- I was working on it in a separate branch, so I've merged it into this one. This PR now standardizes the XLA model kwargs preparation, and most models can use the XLA functionality. Some models were incompatible for different reasons, so there is a new flag to gate XLA generation (and the flag is set in the problematic architectures). Finally, I'm also considering adding a general test.
```diff
@@ -722,6 +722,8 @@ def __init__(self, config, *inputs, **kwargs):
         self.lm_head = tf.keras.layers.Dense(
             config.vocab_size, kernel_initializer=get_initializer(config.initializer_range), name="lm_head"
         )
+        # TODO (Joao): investigate why GPTJ has numerical issues in XLA generate
+        self.supports_xla_generation = False
```
Let's maybe open a "Good second issue" after this PR for GPT-J
```python
xla_generate = tf.function(model.generate)

# TODO (joao): there is something not quite right with XLA T5 -- as we increase `max_length` the two outputs
# drift apart, where the XLA version clearly degrades in quality. XLA-related variables look fine (they are
```
Let's also open an issue for this one after the PR is merged. My intuition is that there is a problem with the relative position embeddings, given that you don't see the same problem in Bart. Relative position embeddings are quite tricky and could easily have been messed up by XLA.
tests/test_modeling_tf_common.py
```diff
@@ -1600,6 +1600,51 @@ def test_dataset_conversion(self):
         model.compile(optimizer="sgd", run_eagerly=True)
         model.train_on_batch(test_batch, test_batch_labels)

+    def test_xla_generate_fast(self):
```
Nice! Do you know how much time this adds on CircleCI?
That's a good question. The bottleneck of the test is XLA compilation, which is a single-core operation, so it should take roughly the same time on all machines -- on my machine it takes between 2 and 8 seconds per model, depending on the model.
Cool! Let's merge it :-)
We can open issues right after for:
- T5 XLA degrading performance (I think this is a subtle bug -- if BART doesn't have it, it's either numerical issues, which would be a bit weird since T5 has been trained on TPU, or maybe the relative position embeddings)
- GPT-J (important model)
Note: as per the comment above, if this PR gets merged as it is, I will open an issue to track problems with XLA generation (relevant models failing the fast tests, as well as models failing the slow tests).
Overall this looks good, although I can't claim to understand the beam search code too well! I have a couple of questions, but if it's passing tests then I assume the problem is just my understanding rather than an actual bug.
```python
# Write the last slice to the first open location in the padded past array
# and then truncate the last slice off the array
new_past_layer[i] = dynamic_update_slice(
    past_layer[i][:, :, :-1], update_slice, slice_start_base * new_past_index
```
Should this be the suggestion below? It feels like you're taking a slice of `past_layer[i]` and then updating that slice (which will not necessarily update the original). Or does that just work out because it's happening in XLA?
```diff
- past_layer[i][:, :, :-1], update_slice, slice_start_base * new_past_index
+ past_layer[i], update_slice, slice_start_base * new_past_index
```
I will doublecheck that 👍
My suspicion is that:
- It works (or tests would fail)
- Taking a slice does not make a copy of the underlying data
- Because you use the low-level XLA op `dynamic_update_slice` to write the new data, no copy-on-write happens and no new tensor is created. Instead, you just update the slice in place, which also updates the original tensor. This is normally extremely forbidden in TF.
If I'm right about this then we exist in a state of Tensorflow sin and should seek confession urgently once we merge the PR, and the code might break if the underlying implementation changes. But it's probably fine and we can leave it for now with a little TODO warning so people know what happened if it breaks 3 years from now, lol.
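For intuition on why this worry is reasonable: TF tensors are immutable, but in numpy (used here purely as a contrast, not what the PR code runs) a basic slice really is a view, and writing through it does mutate the original:

```python
import numpy as np

a = np.arange(6)
view = a[:-1]    # basic slicing returns a view, not a copy
view[0] = 100    # writing through the view mutates the shared buffer,
                 # so the change is visible in the original array `a`
```

Whether the XLA op behaves like this is exactly what the check in the next comment settles empirically.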
TL;DR it is a safe operation 🙌
If we run
```python
import tensorflow as tf
from tensorflow.compiler.tf2xla.python.xla import dynamic_update_slice

a = tf.range(10)
print("a (before updates) ", a, id(a))
# a (before updates)  tf.Tensor([0 1 2 3 4 5 6 7 8 9], shape=(10,), dtype=int32) 140127625611600

b = dynamic_update_slice(a[:-1], tf.constant([100]), tf.constant([5]))
print("a (after updates) ", a, id(a))
# a (after updates)  tf.Tensor([0 1 2 3 4 5 6 7 8 9], shape=(10,), dtype=int32) 140127625611600
print("b ", b, id(b))
# b  tf.Tensor([  0   1   2   3   4 100   6   7   8], shape=(9,), dtype=int32) 140127625611984

# this shouldn't work, but it does
c = dynamic_update_slice(a[:-1], tf.constant([100]), tf.constant([20]))
print("c ", c, id(c))
# c  tf.Tensor([  0   1   2   3   4   5   6   7 100], shape=(9,), dtype=int32) 140127625612368
```
we can see that:
- the original tensor is not touched (praise be 🙏)
- the output tensor is a new tensor
- the `[:-1]` slices the input tensor, and the output has the shape of the sliced input tensor
- indices out of bounds are clipped (woot?)
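The clipping behavior can be pinned down without TF. This numpy sketch (the function is hypothetical, written to mimic XLA's documented `DynamicUpdateSlice` semantics) reproduces the outputs above:

```python
import numpy as np

def dynamic_update_slice_sim(operand, update, start_indices):
    """Numpy simulation of XLA's DynamicUpdateSlice: start indices are
    clipped so the update always fits in bounds, and a new array is
    returned (the operand is never mutated)."""
    operand = np.asarray(operand)
    update = np.asarray(update)
    result = operand.copy()
    region = []
    for start, op_dim, up_dim in zip(start_indices, operand.shape, update.shape):
        start = min(max(int(start), 0), op_dim - up_dim)  # XLA-style clipping
        region.append(slice(start, start + up_dim))
    result[tuple(region)] = update
    return result

a = np.arange(10)
b = dynamic_update_slice_sim(a[:-1], [100], [5])   # in-bounds write at index 5
c = dynamic_update_slice_sim(a[:-1], [100], [20])  # start index 20 clipped to 8
# b -> [0 1 2 3 4 100 6 7 8], c -> [0 1 2 3 4 5 6 7 100], a untouched
```

The clipping rule (`start` clamped to `operand_dim - update_dim`) is why the out-of-bounds write in the TF example lands on the last element instead of raising.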
```python
# Write the last slice to the first open location in the padded past array
# and then truncate the last slice off the array
new_past[i] = dynamic_update_slice(
    past[i][:, :, :, :-1], update_slice, slice_start_base * new_past_index
```
Same here - it looks like a copied slice is being updated and that possibly only works because of XLA shenanigans!
```diff
- past[i][:, :, :, :-1], update_slice, slice_start_base * new_past_index
+ past[i], update_slice, slice_start_base * new_past_index
```
…XLA-generate-compatible (huggingface#17857)

* working beam search 🎉
* XLA generation compatible with ALL classes
* add xla generation slow test
```python
for past_layer in past:
    new_past_layer = list(past_layer)
    for i in range(len(new_past_layer[:2])):
        new_past_layer[i] = tf.pad(past_layer[i], padding_values)
```
@gante I was wondering why you decided to go for right padding as opposed to, say, left padding, which could be simpler (no special treatment for relative positional embeddings, no `dynamic_update_slice` required).
Hey @aashiqmuhamed!
To be candid, we did not even consider other ways to pre-populate the fixed-shape tensors. Maybe left padding would lead to faster code, as it would imply a concatenation on the right and a crop on the left (perhaps faster than a scatter operation).
I'm sure there are many optimization opportunities in the TF XLA generate codebase -- for instance, beam search relies on many expensive reshapes, which are not necessary. (`dynamic_update_slice` is simply syntactic sugar; it can be replaced by a more verbose scatter operation.)
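The right-padding scheme being discussed can be sketched in numpy (toy shapes and variable names are hypothetical; the real code uses TF tensors and `dynamic_update_slice`): the cache is allocated at its final, XLA-friendly fixed shape up front, and each generation step scatters its new key/value slice into the first open slot instead of concatenating:

```python
import numpy as np

batch, heads, max_length, head_dim = 1, 2, 6, 4
# pre-allocate the cache at its final fixed shape, zero-padded on the right
cache = np.zeros((batch, heads, max_length, head_dim))

for step in range(3):
    new_kv = np.full((batch, heads, 1, head_dim), float(step + 1))
    # scatter-style write into the first open position along the length axis
    # (the role dynamic_update_slice plays for the real TF cache)
    cache[:, :, step : step + 1, :] = new_kv

# positions 0..2 now hold the per-step values; positions 3..5 remain padding
```

Because the cache shape never changes across steps, XLA only has to compile the step function once; a left-padded layout would instead shift the valid region on every step.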
What does this PR do?
The much-awaited PR -- beam search is now XLA compatible. GPT2 is the only model with XLA beam search tests; more models will follow in subsequent PRs 🎊 Preliminary tests on my machine show that XLA beam search on GPU is ~26x faster (greedy search and sample are ~30x faster).
Slow tests have been run for the usual generate models (gpt2, t5, rag, speech_to_text, encoder_decoder, vision_encoder_decoder, bart).
EDIT: I've also generalized a few functions, and now ALL models that are compatible with generate are also compatible with XLA generate (with a few exceptions, when the models have no past cache support)
A hard-earned lesson which is kinda obvious in hindsight: `if` branches can make the XLA compiler confused about variable shapes, tagging their shape as `<unknown>`, which in turn causes all sorts of exceptions. Out of curiosity, I tried replacing the `if` with `tf.cond`, but the `<unknown>` shape persisted (because the tensor could indeed have a different shape at tracing time, depending on the branch taken).