TF: XLA beam search + most generation-compatible models are now also XLA-generate-compatible #17857

Merged · 9 commits · Jun 29, 2022
38 changes: 12 additions & 26 deletions src/transformers/generation_tf_utils.py
@@ -2484,9 +2484,6 @@ def beam_search(

         def flatten_beam_dim(tensor, batch_axis=0):
             """Flattens the first two dimensions of a non-scalar array."""
-            # ignore scalars (e.g. cache index)
-            if tf.rank(tensor) == 0:
-                return tensor
@gante (Member, Author) commented on lines -2487 to -2489 (Jun 23, 2022):
This pattern was inherited from FLAX. Contrary to FLAX, it seems, this `if` throws the XLA compiler off balance, causing `tensor` to have an unknown shape. All the problems I was seeing before were downstream consequences of that unknown shape.
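A minimal sketch of the failure mode (an editor's illustration, not code from the PR): inside tf.function, AutoGraph lowers a Python `if` on a tensor condition into tf.cond, and when the two branches return values of different shapes, the merged result loses its static shape, which XLA requires.

    import tensorflow as tf

    @tf.function(jit_compile=True)
    def flatten_with_guard(tensor):  # hypothetical helper for illustration
        # tf.rank(tensor) is symbolic under tracing, so this `if` becomes a
        # tf.cond; its branches return a scalar vs. a rank-1 tensor, and the
        # merged output therefore has no known static shape.
        if tf.rank(tensor) == 0:
            return tensor
        return tf.reshape(tensor, [-1])

    @tf.function(jit_compile=True)
    def flatten_without_guard(tensor):  # the PR's approach: no rank-0 escape
        return tf.reshape(tensor, [-1])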

             return tf.reshape(
                 tensor,
                 tensor.shape[:batch_axis]
@@ -2496,9 +2493,6 @@ def flatten_beam_dim(tensor, batch_axis=0):

         def unflatten_beam_dim(tensor, batch_size, num_beams, batch_axis=0):
             """Unflattens the first, flat batch*beam dimension of a non-scalar array."""
-            # ignore scalars (e.g. cache index)
-            if tf.rank(tensor) == 0:
-                return tensor
             return tf.reshape(
                 tensor, tensor.shape[:batch_axis] + [batch_size, num_beams] + tensor.shape[batch_axis + 1 :]
             )
@@ -2507,27 +2501,19 @@ def gather_beams(nested, beam_indices, batch_axis=0):
             """Gathers the beam slices indexed by beam_indices into new beam array."""

             def gather_fn(tensor):
-                # ignore scalars (e.g. cache index)
-                if tf.rank(tensor) == 0:
-                    return tensor
-                else:
-                    if batch_axis > 0:
-                        # pushes all dimensions before the batch to the end, so we get (batch, beam_id, ...)
-                        perm = [axis for axis in range(tf.rank(tensor)) if axis >= batch_axis] + list(
-                            range(batch_axis)
-                        )
-                        tensor = tf.transpose(tensor, perm=perm)
+                if batch_axis > 0:
+                    # pushes all dimensions before the batch to the end, so we get (batch, beam_id, ...)
+                    perm = tf.concat((tf.range(tf.rank(tensor))[batch_axis:], tf.range(batch_axis)), axis=0)
+                    tensor = tf.transpose(tensor, perm=perm)

-                    gathered_tensor = tf.gather(params=tensor, indices=beam_indices, axis=1, batch_dims=1)
-                    if batch_axis > 0:
-                        # transposes back to the original dimensions
-                        perm = [axis for axis in range(tf.rank(tensor)) if axis >= batch_axis] + list(
-                            range(batch_axis)
-                        )
-                        perm = tf.math.invert_permutation(perm)
-                        gathered_tensor = tf.transpose(gathered_tensor, perm=perm)
+                gathered_tensor = tf.gather(params=tensor, indices=beam_indices, axis=1, batch_dims=1)
+                if batch_axis > 0:
+                    # transposes back to the original dimensions
+                    perm = tf.concat((tf.range(tf.rank(tensor))[batch_axis:], tf.range(batch_axis)), axis=0)
+                    perm = tf.math.invert_permutation(perm)
+                    gathered_tensor = tf.transpose(gathered_tensor, perm=perm)

-                    return gathered_tensor
+                return gathered_tensor

             return tf.nest.map_structure(gather_fn, nested)
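The new perm is built with tensor ops because, under XLA tracing, tf.rank(tensor) is a symbolic tensor that a Python list comprehension cannot iterate over. A quick sketch (editor's illustration, with a concrete rank and batch_axis) of the permutation and its inverse:

    import tensorflow as tf

    rank, batch_axis = 4, 1  # assumed values for illustration
    # Rotate the axes so the batch/beam axes come first: [1, 2, 3, 0]
    perm = tf.concat((tf.range(rank)[batch_axis:], tf.range(batch_axis)), axis=0)
    # invert_permutation restores the original axis order afterwards: [3, 0, 1, 2]
    inv = tf.math.invert_permutation(perm)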

@@ -2734,7 +2720,7 @@ def beam_search_body_fn(
             # - add length penalty
             # - make sure no scores can be added anymore if beam is full
             # - make sure still running sequences cannot be chosen as finalized beam
-            topk_log_probs = topk_log_probs / (cur_len**length_penalty)
+            topk_log_probs = topk_log_probs / (tf.cast(cur_len, dtype=tf.float32) ** length_penalty)
             beams_in_batch_are_full = (
                 tf.broadcast_to(
                     tf.math.reduce_all(is_sent_finished, axis=-1, keepdims=True), did_topk_just_finished.shape
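The cast is needed because cur_len is an integer tensor inside the traced loop and TensorFlow does not auto-promote dtypes: raising an int32 tensor to a (typically fractional) float power either fails or loses precision. A small sketch of the fix (editor's illustration, with an assumed length_penalty):

    import tensorflow as tf

    cur_len = tf.constant(5, dtype=tf.int32)  # decoding step counter
    length_penalty = 0.6                      # assumed value for illustration
    # cur_len ** length_penalty would coerce the float exponent to int32;
    # casting the base to float32 first keeps the penalty exact:
    penalty = tf.cast(cur_len, dtype=tf.float32) ** length_penalty  # ~2.63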
24 changes: 24 additions & 0 deletions tests/models/gpt2/test_modeling_tf_gpt2.py
@@ -627,3 +627,27 @@ def test_lm_generate_gpt2_sample_xla(self):
         output_ids = xla_generate(**input_ids, do_sample=True, seed=[7, 0])
         output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
         self.assertListEqual(output_strings, expected_output_string_xla)
+
+    @slow
+    def test_lm_generate_gpt2_beam_search_xla(self):
+        model = TFGPT2LMHeadModel.from_pretrained("gpt2")
+        tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
+
+        tokenizer.pad_token = tokenizer.eos_token
+        tokenizer.padding_side = "left"
+
+        sentences = ["The dog", "The flying machine"]
+        expected_output_strings = [
+            "The dog was found in the backyard of a home in the 6500 block of South Main Street",
+            "The flying machine is a very powerful machine, but it's not a very powerful machine. It's",
+        ]
+        input_ids = tokenizer(sentences, return_tensors="tf", padding=True)
+
+        output_ids = model.generate(**input_ids, do_sample=False, num_beams=2)
+        output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
+        self.assertListEqual(output_strings, expected_output_strings)
+
+        xla_generate = tf.function(model.generate, jit_compile=True)
+        output_ids = xla_generate(**input_ids, do_sample=False, num_beams=2)
+        output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
+        self.assertListEqual(output_strings, expected_output_strings)
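One usage note on the XLA path above (editor's gloss, not part of the PR): tf.function retraces, and XLA recompiles, whenever the input shapes change, so padding every batch to a fixed multiple keeps repeated calls to the compiled function fast. A sketch under that assumption, reusing the test's tokenizer and xla_generate:

    # pad_to_multiple_of fixes the sequence length so later batches of
    # different lengths reuse the already-compiled program.
    input_ids = tokenizer(sentences, return_tensors="tf", padding=True, pad_to_multiple_of=64)
    output_ids = xla_generate(**input_ids, do_sample=False, num_beams=2)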