Skip to content

Commit

Permalink
TF: GPT-J compatible with XLA generation (#17986)
Browse files Browse the repository at this point in the history
  • Loading branch information
gante committed Jul 6, 2022
1 parent bf37e5c commit 360719a
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 142 deletions.
90 changes: 41 additions & 49 deletions src/transformers/models/gptj/modeling_tf_gptj.py
Expand Up @@ -60,14 +60,12 @@
]


def fixed_pos_embedding(x: tf.Tensor, seq_dim: int = 1, seq_len: Optional[int] = None) -> Tuple[tf.Tensor, tf.Tensor]:
dim = shape_list(x)[-1]
if seq_len is None:
seq_len = shape_list(x)[seq_dim]
def create_sinusoidal_positions(num_pos: int, dim: int) -> tf.Tensor:
inv_freq = tf.cast(1.0 / (10000 ** (tf.range(0, dim, 2) / dim)), tf.float32)
seq_len_range = tf.cast(tf.range(seq_len), tf.float32)
sinusoid_inp = tf.cast(tf.einsum("i , j -> i j", seq_len_range, inv_freq), tf.float32)
return tf.cast(tf.sin(sinusoid_inp), dtype=x.dtype), tf.cast(tf.cos(sinusoid_inp), dtype=x.dtype)
sinusoid_inp = tf.cast(tf.einsum("i , j -> i j", tf.range(num_pos, dtype=tf.float32), inv_freq), tf.float32)
sin, cos = tf.sin(sinusoid_inp), tf.cos(sinusoid_inp)
out = tf.concat((sin, cos), axis=1)
return out


def rotate_every_two(x: tf.Tensor) -> tf.Tensor:
Expand All @@ -77,11 +75,11 @@ def rotate_every_two(x: tf.Tensor) -> tf.Tensor:
return rotate_half_tensor


def apply_rotary_pos_emb(x: tf.Tensor, sincos: tf.Tensor, offset: int = 0) -> tf.Tensor:
def apply_rotary_pos_emb(tensor: tf.Tensor, sincos: tf.Tensor) -> tf.Tensor:
sin_pos, cos_pos = sincos
sin_pos = tf.repeat(sin_pos[None, offset : shape_list(x)[1] + offset, None, :], 2, 3)
cos_pos = tf.repeat(cos_pos[None, offset : shape_list(x)[1] + offset, None, :], 2, 3)
return (x * cos_pos) + (rotate_every_two(x) * sin_pos)
sin_pos = tf.repeat(sin_pos[:, :, None, :], 2, 3)
cos_pos = tf.repeat(cos_pos[:, :, None, :], 2, 3)
return (tensor * cos_pos) + (rotate_every_two(tensor) * sin_pos)


class TFGPTJAttention(tf.keras.layers.Layer):
Expand Down Expand Up @@ -132,6 +130,8 @@ def __init__(self, config: GPTJConfig, **kwargs):
tf.cast(tf.experimental.numpy.tril(tf.ones((self.max_positions, self.max_positions))), tf.int8),
(1, 1, self.max_positions, self.max_positions),
)
pos_embd_dim = self.rotary_dim or self.embed_dim
self.embed_positions = create_sinusoidal_positions(self.max_positions, pos_embd_dim)

def get_causal_mask(self, key_length, query_length) -> tf.Tensor:
return tf.cast(self.lower_triangle_mask[:, :, key_length - query_length : key_length, :key_length], tf.bool)
Expand Down Expand Up @@ -207,8 +207,9 @@ def _attn(
def call(
self,
hidden_states: tf.Tensor,
attention_mask: Optional[tf.Tensor] = None,
layer_past: Optional[Tuple[tf.Tensor, tf.Tensor]] = None,
attention_mask: Optional[tf.Tensor] = None,
position_ids: Optional[tf.Tensor] = None,
head_mask: Optional[tf.Tensor] = None,
use_cache: bool = False,
output_attentions: bool = False,
Expand All @@ -221,30 +222,23 @@ def call(
key = self._split_heads(key, True)
value = self._split_heads(value, False)

seq_len = shape_list(key)[1]
offset = 0

if layer_past is not None:
offset = shape_list(layer_past[0])[-2]
seq_len += offset

sincos = tf.gather(self.embed_positions, position_ids, axis=0)
sincos = tf.split(sincos, 2, axis=-1)
if self.rotary_dim is not None:
k_rot = key[:, :, :, : self.rotary_dim]
k_pass = key[:, :, :, self.rotary_dim :]

q_rot = query[:, :, :, : self.rotary_dim]
q_pass = query[:, :, :, self.rotary_dim :]

sincos = fixed_pos_embedding(k_rot, 1, seq_len=seq_len)
k_rot = apply_rotary_pos_emb(k_rot, sincos, offset=offset)
q_rot = apply_rotary_pos_emb(q_rot, sincos, offset=offset)
k_rot = apply_rotary_pos_emb(k_rot, sincos)
q_rot = apply_rotary_pos_emb(q_rot, sincos)

key = tf.concat((k_rot, k_pass), axis=-1)
query = tf.concat((q_rot, q_pass), axis=-1)
else:
sincos = fixed_pos_embedding(key, 1, seq_len=seq_len)
key = apply_rotary_pos_emb(key, sincos, offset=offset)
query = apply_rotary_pos_emb(query, sincos, offset=offset)
key = apply_rotary_pos_emb(key, sincos)
query = apply_rotary_pos_emb(query, sincos)

key = tf.transpose(key, (0, 2, 1, 3))
query = tf.transpose(query, (0, 2, 1, 3))
Expand Down Expand Up @@ -310,16 +304,18 @@ def call(
hidden_states: tf.Tensor,
layer_past: Optional[tf.Tensor] = None,
attention_mask: Optional[tf.Tensor] = None,
position_ids: Optional[tf.Tensor] = None,
head_mask: Optional[tf.Tensor] = None,
use_cache: bool = False,
output_attentions: bool = False,
):
residual = hidden_states
hidden_states = self.ln_1(hidden_states)
attn_outputs = self.attn(
hidden_states,
hidden_states=hidden_states,
layer_past=layer_past,
attention_mask=attention_mask,
position_ids=position_ids,
head_mask=head_mask,
use_cache=use_cache,
output_attentions=output_attentions,
Expand Down Expand Up @@ -466,12 +462,13 @@ def call(
all_hidden_states = all_hidden_states + (tf.reshape(hidden_states, output_shape),)

outputs = block(
hidden_states,
layer_past,
attention_mask,
head_mask[i],
use_cache,
output_attentions,
hidden_states=hidden_states,
layer_past=layer_past,
attention_mask=attention_mask,
position_ids=position_ids,
head_mask=head_mask[i],
use_cache=use_cache,
output_attentions=output_attentions,
training=training,
)

Expand Down Expand Up @@ -722,41 +719,36 @@ def __init__(self, config, *inputs, **kwargs):
self.lm_head = tf.keras.layers.Dense(
config.vocab_size, kernel_initializer=get_initializer(config.initializer_range), name="lm_head"
)
# TODO (Joao): investigate why GPTJ has numerical issues in XLA generate
self.supports_xla_generation = False

def get_output_embeddings(self):
return self.lm_head

def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings

def prepare_inputs_for_generation(self, inputs, past=None, use_cache=None, use_xla=False, **kwargs):
# TODO: (Joao) after the TF generator is complete, update GPT2 TF generation to match PT's. NB -- some GPT2
# tests will need to be fixed after the change

def prepare_inputs_for_generation(self, inputs, past=None, use_cache=None, **kwargs):
token_type_ids = kwargs.get("token_type_ids", None)
# only last token for inputs_ids if past is defined in kwargs
if past:
inputs = tf.expand_dims(inputs[:, -1], -1)
if token_type_ids is not None:
token_type_ids = tf.expand_dims(token_type_ids[:, -1], -1)

position_ids = kwargs.get("position_ids", None)
attention_mask = kwargs.get("attention_mask", None)

# TODO(pvp, Joao) - this `if use_xla` statement can be removed, but is left
# for a future PR to not change too many things for now.
# All statements in this if case apply for both xla and non-xla (as they already do in PyTorch)
position_ids = None
attention_mask = None
if use_xla:
attention_mask = kwargs.get("attention_mask", None)
if past is not None and attention_mask is not None:
position_ids = tf.reduce_sum(attention_mask, axis=1, keepdims=True) - 1
elif attention_mask is not None:
position_ids = tf.math.cumsum(attention_mask, axis=1, exclusive=True)
if attention_mask is not None and position_ids is None:
position_ids = tf.math.cumsum(attention_mask, axis=-1, exclusive=True)
if past:
position_ids = tf.expand_dims(position_ids[:, -1], -1)

return {
"input_ids": inputs,
"attention_mask": attention_mask,
"position_ids": position_ids,
"past": past,
"use_cache": use_cache,
"token_type_ids": token_type_ids,
}

@unpack_inputs
Expand Down
137 changes: 44 additions & 93 deletions tests/models/gptj/test_modeling_tf_gptj.py
Expand Up @@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import datetime
import unittest

from transformers import AutoTokenizer, GPTJConfig, is_tf_available
Expand Down Expand Up @@ -48,6 +47,7 @@ def __init__(self, parent):
self.use_mc_token_ids = True
self.vocab_size = 99
self.hidden_size = 32
self.rotary_dim = 4
self.num_hidden_layers = 5
self.num_attention_heads = 4
self.intermediate_size = 37
Expand Down Expand Up @@ -103,6 +103,7 @@ def prepare_config_and_inputs(self):
bos_token_id=self.bos_token_id,
eos_token_id=self.eos_token_id,
pad_token_id=self.pad_token_id,
rotary_dim=self.rotary_dim,
return_dict=True,
)

Expand Down Expand Up @@ -359,10 +360,10 @@ def test_resize_token_embeddings(self):


@require_tf
@tooslow
# Marked as @tooslow due to GPU OOM -- but still useful to run locally. Requires ~39GB of RAM.
class TFGPTJModelLanguageGenerationTest(unittest.TestCase):
@tooslow
def test_lm_generate_gptj(self):
# Marked as @tooslow due to GPU OOM
model = TFGPTJForCausalLM.from_pretrained("EleutherAI/gpt-j-6B", from_pt=True)
input_ids = tf.convert_to_tensor([[464, 3290]], dtype=tf.int32) # The dog
# fmt: off
Expand All @@ -372,74 +373,20 @@ def test_lm_generate_gptj(self):
output_ids = model.generate(input_ids, do_sample=False)
self.assertListEqual(output_ids[0].numpy().tolist(), expected_output_ids)

@tooslow
def test_gptj_sample(self):
# Marked as @tooslow due to GPU OOM (issue #13676)
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B", revision="float16")
model = TFGPTJForCausalLM.from_pretrained("EleutherAI/gpt-j-6B", revision="float16", from_pt=True)

tf.random.set_seed(0)
tokenized = tokenizer("Today is a nice day and", return_tensors="tf", return_token_type_ids=True)
input_ids, token_type_ids = tokenized.input_ids, tokenized.token_type_ids
output_ids = model.generate(input_ids, do_sample=True)
tokenized = tokenizer("Today is a nice day and", return_tensors="tf")
# forces the generation to happen on CPU, to avoid GPU-related quirks
with tf.device(":/CPU:0"):
output_ids = model.generate(**tokenized, do_sample=True, seed=[42, 0])
output_str = tokenizer.decode(output_ids[0], skip_special_tokens=True)

output_seq = model.generate(input_ids=input_ids, do_sample=True, num_return_sequences=5)
output_seq_tt = model.generate(
input_ids=input_ids, token_type_ids=token_type_ids, do_sample=True, num_return_sequences=5
)
output_seq_strs = tokenizer.batch_decode(output_seq, skip_special_tokens=True)
output_seq_tt_strs = tokenizer.batch_decode(output_seq_tt, skip_special_tokens=True)

EXPECTED_OUTPUT_STR = "Today is a nice day and I am taking an hour to sit in the hammock and just enjoy"

EXPECTED_OUTPUT_STR = "Today is a nice day and I’m going to go for a walk. I’"
self.assertEqual(output_str, EXPECTED_OUTPUT_STR)
self.assertTrue(
all([output_seq_strs[idx] != output_seq_tt_strs[idx] for idx in range(len(output_seq_tt_strs))])
) # token_type_ids should change output

@slow
@unittest.skip(reason="TF generate currently has no time-based stopping criteria")
def test_gptj_sample_max_time(self):
tokenizer = AutoTokenizer.from_pretrained("anton-l/gpt-j-tiny-random")
model = TFGPTJForCausalLM.from_pretrained("anton-l/gpt-j-tiny-random", from_pt=True)

input_ids = tokenizer("Today is a nice day and", return_tensors="tf", return_token_type_ids=True).input_ids

MAX_TIME = 0.5

start = datetime.datetime.now()
model.generate(input_ids, do_sample=True, max_time=MAX_TIME, max_length=256)
duration = datetime.datetime.now() - start
self.assertGreater(duration, datetime.timedelta(seconds=MAX_TIME))
self.assertLess(duration, datetime.timedelta(seconds=1.5 * MAX_TIME))

start = datetime.datetime.now()
model.generate(input_ids, do_sample=False, max_time=MAX_TIME, max_length=256)
duration = datetime.datetime.now() - start
self.assertGreater(duration, datetime.timedelta(seconds=MAX_TIME))
self.assertLess(duration, datetime.timedelta(seconds=1.5 * MAX_TIME))

start = datetime.datetime.now()
model.generate(input_ids, do_sample=False, num_beams=2, max_time=MAX_TIME, max_length=256)
duration = datetime.datetime.now() - start
self.assertGreater(duration, datetime.timedelta(seconds=MAX_TIME))
self.assertLess(duration, datetime.timedelta(seconds=1.5 * MAX_TIME))

start = datetime.datetime.now()
model.generate(input_ids, do_sample=True, num_beams=2, max_time=MAX_TIME, max_length=256)
duration = datetime.datetime.now() - start
self.assertGreater(duration, datetime.timedelta(seconds=MAX_TIME))
self.assertLess(duration, datetime.timedelta(seconds=1.5 * MAX_TIME))

start = datetime.datetime.now()
model.generate(input_ids, do_sample=False, max_time=None, max_length=256)
duration = datetime.datetime.now() - start
self.assertGreater(duration, datetime.timedelta(seconds=1.5 * MAX_TIME))

@tooslow
def test_batch_generation(self):
# Marked as @tooslow due to GPU OOM
def _get_beam_search_test_objects(self):
model = TFGPTJForCausalLM.from_pretrained("EleutherAI/gpt-j-6B", revision="float16", from_pt=True)
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B", revision="float16")

Expand All @@ -454,42 +401,46 @@ def test_batch_generation(self):
"Hello, my dog is a little",
"Today, I",
]
expected_output_sentences = [
"Hello, my dog is a little over a year old and has been diagnosed with hip dysplasia",
"Today, I’m going to be talking about a topic that’",
]
return model, tokenizer, sentences, expected_output_sentences

inputs = tokenizer(sentences, return_tensors="tf", padding=True)
input_ids = inputs["input_ids"]
token_type_ids = tf.concat(
[
tf.zeros((input_ids.shape[0], input_ids.shape[1] - 1), dtype=tf.int64),
500 * tf.ones((input_ids.shape[0], 1), dtype=tf.int64),
],
axis=-1,
)
def test_batch_beam_search(self):
# Confirms that we get the expected results with left-padded beam search
model, tokenizer, sentences, expected_output_sentences = self._get_beam_search_test_objects()

outputs = model.generate(input_ids=input_ids, attention_mask=inputs["attention_mask"])
outputs_tt = model.generate(
input_ids=input_ids,
attention_mask=inputs["attention_mask"],
token_type_ids=token_type_ids,
)
inputs = tokenizer(sentences, return_tensors="tf", padding=True)
outputs = model.generate(**inputs, do_sample=False, num_beams=2)
batch_out_sentence = tokenizer.batch_decode(outputs, skip_special_tokens=True)
self.assertListEqual(expected_output_sentences, batch_out_sentence)

inputs_non_padded = tokenizer(sentences[0], return_tensors="tf").input_ids
output_non_padded = model.generate(input_ids=inputs_non_padded)
def test_batch_left_padding(self):
# Confirms that left-padding is working properly
model, tokenizer, sentences, expected_output_sentences = self._get_beam_search_test_objects()

inputs = tokenizer(sentences, return_tensors="tf", padding=True)
inputs_non_padded = tokenizer(sentences[0], return_tensors="tf")
output_non_padded = model.generate(**inputs_non_padded, do_sample=False, num_beams=2)
num_paddings = (
shape_list(inputs_non_padded)[-1] - tf.reduce_sum(tf.cast(inputs["attention_mask"][-1], tf.int64)).numpy()
shape_list(inputs_non_padded["input_ids"])[-1]
- tf.reduce_sum(tf.cast(inputs["attention_mask"][-1], tf.int64)).numpy()
)
inputs_padded = tokenizer(sentences[1], return_tensors="tf")
output_padded = model.generate(
**inputs_padded, do_sample=False, num_beams=2, max_length=model.config.max_length - num_paddings
)
inputs_padded = tokenizer(sentences[1], return_tensors="tf").input_ids
output_padded = model.generate(input_ids=inputs_padded, max_length=model.config.max_length - num_paddings)

batch_out_sentence = tokenizer.batch_decode(outputs, skip_special_tokens=True)
batch_out_sentence_tt = tokenizer.batch_decode(outputs_tt, skip_special_tokens=True)
non_padded_sentence = tokenizer.decode(output_non_padded[0], skip_special_tokens=True)
padded_sentence = tokenizer.decode(output_padded[0], skip_special_tokens=True)
self.assertListEqual(expected_output_sentences, [non_padded_sentence, padded_sentence])

expected_output_sentence = [
"Hello, my dog is a little over a year old and has been diagnosed with a heart murmur",
"Today, I’m going to share with you a few of my favorite",
]
self.assertListEqual(expected_output_sentence, batch_out_sentence)
self.assertTrue(batch_out_sentence_tt != batch_out_sentence) # token_type_ids should change output
self.assertListEqual(expected_output_sentence, [non_padded_sentence, padded_sentence])
def test_xla_beam_search(self):
# Confirms that XLA is working properly
model, tokenizer, sentences, expected_output_sentences = self._get_beam_search_test_objects()

inputs = tokenizer(sentences, return_tensors="tf", padding=True)
xla_generate = tf.function(model.generate, jit_compile=True)
outputs_xla = xla_generate(**inputs, do_sample=False, num_beams=2)
xla_sentence = tokenizer.batch_decode(outputs_xla, skip_special_tokens=True)
self.assertListEqual(expected_output_sentences, xla_sentence)

0 comments on commit 360719a

Please sign in to comment.