TF: XLA beam search + most generation-compatible models are now also XLA-generate-compatible #17857

Merged
merged 9 commits into huggingface:main on Jun 29, 2022

Conversation

gante
Member

@gante gante commented Jun 23, 2022

What does this PR do?

The much-awaited PR -- beam search is now XLA compatible. GPT2 is the only model with XLA beam search tests; more models will follow in subsequent PRs 🎊 Preliminary tests on my machine show that XLA beam search on GPU is ~26x faster (greedy search and sample are ~30x faster).

Slow tests have been run for the usual generate models (gpt2, t5, rag, speech_to_text, encoder_decoder, vision_encoder_decoder, bart).

EDIT: I've also generalized a few functions, and now ALL models that are compatible with generate are also compatible with XLA generate (with a few exceptions for models that have no past cache support).


A hard-earned lesson that is somewhat obvious in hindsight: `if` branches can confuse the XLA compiler about variable shapes, tagging their shape as <unknown>, which in turn causes all sorts of exceptions. Out of curiosity, I tried replacing the `if` with tf.cond, but the <unknown> shape persisted (because the tensor could indeed have a different shape at tracing time, depending on the branch taken).
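
A minimal sketch of the effect described above (hypothetical code, not from the PR): when AutoGraph lowers a Python `if` on a tensor condition to tf.cond and the two branches produce different static shapes, the traced output shape is tagged as unknown.

import tensorflow as tf

@tf.function
def maybe_truncate(x, flag):
    if flag:           # tensor condition -> lowered to tf.cond by AutoGraph
        return x[:-1]  # static shape (9,)
    return x           # static shape (10,)

concrete = maybe_truncate.get_concrete_function(
    tf.TensorSpec([10], tf.int32), tf.TensorSpec([], tf.bool)
)
print(concrete.structured_outputs.shape)  # (None,) -- the shape is unknown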

@gante gante mentioned this pull request Jun 23, 2022
Comment on lines -2487 to -2489
# ignore scalars (e.g. cache index)
if tf.rank(tensor) == 0:
return tensor
Member Author

@gante gante Jun 23, 2022

This pattern was inherited from Flax. Contrary to Flax, it seems, this `if` throws the XLA compiler off balance, causing `tensor` to have an unknown shape. All the problems I was seeing before were downstream consequences of that unknown shape.
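
For illustration, a hypothetical alternative pattern (not necessarily what the PR does): the static rank, `tensor.shape.rank`, is a plain Python int available at tracing time, so branching on it is resolved before XLA ever sees the graph, unlike branching on the runtime tf.rank() tensor above.

import tensorflow as tf

def _pad_non_scalars(tensor, pad_spec):
    # shape.rank is resolved while tracing; scalars (e.g. the cache index)
    # are returned untouched without introducing a graph-level branch
    if tensor.shape.rank == 0:
        return tensor
    return tf.pad(tensor, pad_spec)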

@HuggingFaceDocBuilderDev

HuggingFaceDocBuilderDev commented Jun 23, 2022

The documentation is not available anymore as the PR was closed or merged.

@gante gante marked this pull request as ready for review June 23, 2022 22:24
@patrickvonplaten
Contributor

Very cool! Can we try it out on at least one encoder-decoder architecture as well (just to confirm that this code holds up there)?

@gante gante changed the title TF: XLA beam search TF: XLA beam search + all generation-compatible models are now also XLA-generate-compatible Jun 24, 2022
@gante gante changed the title TF: XLA beam search + all generation-compatible models are now also XLA-generate-compatible TF: XLA beam search + most generation-compatible models are now also XLA-generate-compatible Jun 24, 2022
@gante
Member Author

gante commented Jun 24, 2022

@patrickvonplaten @Rocketknight1 now with encoder-decoder tests, and ready for review -- I was working on it on a separate branch, so I've merged it into this one. Now, this PR standardizes the XLA model kwargs preparation, and most models can use the XLA functionality. Some models were incompatible for different reasons, so there is a new flag to gate XLA generation (and the flag is set in the problematic architectures).
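
A hypothetical usage sketch of such a gate, assuming `model` is a loaded TF model (the attribute name comes from the diff further down; the exact check inside generate may differ):

import tensorflow as tf

# only compile generate with XLA for architectures that did not opt out
if getattr(model, "supports_xla_generation", True):
    generate_fn = tf.function(model.generate, jit_compile=True)
else:
    generate_fn = model.generate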

Finally, I'm also considering adding a general test like test_xla_generate_fast, but with @slow, beam search, and >100 tokens. It will probably break for a few models (like T5), but at least we would be able to automatically track which models are reliable with XLA beam search -- WDYT?
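
A rough sketch of what such a slow check could look like (hypothetical, not the test added in this PR), using GPT-2 as the example model: compile generate with XLA, run beam search past 100 tokens, and compare against the eager output.

import tensorflow as tf
from transformers import AutoTokenizer, TFAutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = TFAutoModelForCausalLM.from_pretrained("gpt2")
xla_generate = tf.function(model.generate, jit_compile=True)

inputs = tokenizer(["The quick brown fox"], return_tensors="tf")
eager_out = model.generate(**inputs, num_beams=4, max_length=128)
xla_out = xla_generate(**inputs, num_beams=4, max_length=128)
assert eager_out.numpy().tolist() == xla_out.numpy().tolist()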

@@ -722,6 +722,8 @@ def __init__(self, config, *inputs, **kwargs):
self.lm_head = tf.keras.layers.Dense(
config.vocab_size, kernel_initializer=get_initializer(config.initializer_range), name="lm_head"
)
# TODO (Joao): investigate why GPTJ has numerical issues in XLA generate
self.supports_xla_generation = False
Contributor

Let's maybe open a "Good second issue" after this PR for GPT-J

xla_generate = tf.function(model.generate)

# TODO (joao): there is something not quite right with XLA T5 -- as we increase `max_length` the two outputs
# drift apart, with the XLA version clearly degrading in quality. XLA-related variables look fine (they are
Contributor

Let's also open an issue for this one after the PR is merged -- my intuition is that there is a problem with the relative position embeddings, given that you don't see the same problem in BART. Relative position embeddings are quite tricky and could easily have been messed up by XLA.

@@ -1600,6 +1600,51 @@ def test_dataset_conversion(self):
model.compile(optimizer="sgd", run_eagerly=True)
model.train_on_batch(test_batch, test_batch_labels)

def test_xla_generate_fast(self):
Contributor

Nice! Do you know how much time this adds to the Circle CI?

Member Author

That's a good question. The bottleneck of the test is XLA compilation, which is a single-core operation, so it should take roughly the same time on all machines -- on my machine it takes between 2 and 8 seconds per model, depending on the model.
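
A hypothetical timing sketch of how that cost shows up (model name and lengths are illustrative assumptions): the first call to an XLA-compiled generate pays the tracing and compilation time, while later calls with the same input shapes reuse the compiled program.

import time
import tensorflow as tf
from transformers import AutoTokenizer, TFAutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = TFAutoModelForCausalLM.from_pretrained("gpt2")
xla_generate = tf.function(model.generate, jit_compile=True)
inputs = tokenizer("Hello", return_tensors="tf")

for label in ("first call (compiles)", "second call (cached)"):
    start = time.time()
    xla_generate(**inputs, max_length=16)
    print(f"{label}: {time.time() - start:.2f}s")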

Contributor

@patrickvonplaten patrickvonplaten left a comment

Cool! Let's merge it :-)

We can open issues right after for:

  • T5 XLA degrading output quality (I think this is a subtle bug -- if BART doesn't have it, it's either a numerical issue, which would be a bit odd since T5 was trained on TPU, or the relative position embeddings)
  • GPT-J numerical issues in XLA generate (important model)

@gante
Member Author

gante commented Jun 27, 2022

Note: as per the comment above, if this PR gets merged as is, I will open an issue to track the remaining XLA generation problems (relevant models failing the fast tests, as well as models failing the slow tests).

Member

@Rocketknight1 Rocketknight1 left a comment

Overall this looks good, although I can't claim to understand the beam search code too well! I have a couple of questions, but if it's passing tests then I assume the problem is just my understanding rather than an actual bug.

# Write the last slice to the first open location in the padded past array
# and then truncate the last slice off the array
new_past_layer[i] = dynamic_update_slice(
past_layer[i][:, :, :-1], update_slice, slice_start_base * new_past_index
Member

Should this be the suggestion below? It feels like you're taking a slice of past_layer[i] and then updating that slice (which will not necessarily update the original). Or does that just work out because it's happening in XLA?

Suggested change
past_layer[i][:, :, :-1], update_slice, slice_start_base * new_past_index
past_layer[i], update_slice, slice_start_base * new_past_index

Member Author

I will double-check that 👍

Member

My suspicion is that:

  1. It works (or tests would fail)
  2. Taking a slice does not make a copy of the underlying data
  3. Because you use the low-level XLA op dynamic_update_slice to write the new data, no copy-on-write happens, and no new tensor is created. Instead, you just update the slice in-place, which also updates the original tensor. This is normally extremely forbidden in TF.

If I'm right about this then we exist in a state of Tensorflow sin and should seek confession urgently once we merge the PR, and the code might break if the underlying implementation changes. But it's probably fine and we can leave it for now with a little TODO warning so people know what happened if it breaks 3 years from now, lol.

Member Author

@gante gante Jun 29, 2022

TL;DR it is a safe operation 🙌

If we run

import tensorflow as tf
from tensorflow.compiler.tf2xla.python.xla import dynamic_update_slice

a = tf.range(10)
print("a (before updates) ", a, id(a))
# a (before updates)  tf.Tensor([0 1 2 3 4 5 6 7 8 9], shape=(10,), dtype=int32) 140127625611600

b = dynamic_update_slice(a[:-1], tf.constant([100]), tf.constant([5]))
print("a (after updates)  ", a, id(a))
# a (after updates)   tf.Tensor([0 1 2 3 4 5 6 7 8 9], shape=(10,), dtype=int32) 140127625611600
print("b ", b, id(b))
# b  tf.Tensor([  0   1   2   3   4 100   6   7   8], shape=(9,), dtype=int32) 140127625611984

# this shouldn't work, but it does
c = dynamic_update_slice(a[:-1], tf.constant([100]), tf.constant([20]))
print("c ", c, id(c))
# c  tf.Tensor([  0   1   2   3   4   5   6   7 100], shape=(9,), dtype=int32) 140127625612368

we can see that:

  1. the original tensor is not touched (praise be 🙏)
  2. the output tensor is a new tensor
  3. the [:-1] slices the input tensor, and the output has the shape of the sliced input tensor
  4. out-of-bounds indices are clipped (woot?)

# Write the last slice to the first open location in the padded past array
# and then truncate the last slice off the array
new_past[i] = dynamic_update_slice(
past[i][:, :, :, :-1], update_slice, slice_start_base * new_past_index
Member

Same here - it looks like a copied slice is being updated and that possibly only works because of XLA shenanigans!

Suggested change
past[i][:, :, :, :-1], update_slice, slice_start_base * new_past_index
past[i], update_slice, slice_start_base * new_past_index

@gante gante merged commit e6d27ca into huggingface:main Jun 29, 2022
@gante gante deleted the beam_search_4 branch June 29, 2022 11:41
viclzhu pushed a commit to viclzhu/transformers that referenced this pull request Jul 18, 2022
…XLA-generate-compatible (huggingface#17857)

* working beam search 🎉

* XLA generation compatible with ALL classes

* add xla generation slow test
for past_layer in past:
new_past_layer = list(past_layer)
for i in range(len(new_past_layer[:2])):
new_past_layer[i] = tf.pad(past_layer[i], padding_values)
Contributor

@gante I was wondering why you decided to go for right padding as opposed to, say, left padding, which could be simpler (no special treatment for relative positional embeddings, no dynamic_update_slice required).

Member Author

Hey @aashiqmuhamed!

To be candid, we did not even consider other ways of pre-populating the fixed-shape tensors. Maybe left-padding would lead to faster code, as it would imply a concatenation on the right and a cropping on the left (perhaps faster than a scatter operation).

I'm sure there are many optimization opportunities in the TF XLA generate codebase -- for instance, beam search relies on many expensive reshapes, which are not strictly necessary.

(dynamic_update_slice is simply syntactic sugar; it can be replaced by a more verbose scatter operation)
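
A minimal sketch of the right-padding scheme being discussed (hypothetical shapes and variable names): the key/value cache is padded to the target length up front, and each decoding step writes its new slice into the first open position with dynamic_update_slice, so the tensor shape stays fixed and XLA does not need to recompile.

import tensorflow as tf
from tensorflow.compiler.tf2xla.python.xla import dynamic_update_slice

batch, heads, max_len, head_dim = 2, 4, 8, 16
past_key = tf.zeros((batch, heads, max_len, head_dim))    # pre-padded cache
new_key = tf.random.normal((batch, heads, 1, head_dim))   # slice for this step
current_len = tf.constant(3)                              # positions already filled

start_indices = tf.stack([0, 0, current_len, 0])          # write at [:, :, 3, :]
past_key = dynamic_update_slice(past_key, new_key, start_indices)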
