TF: GPT-J compatible with XLA generation #17986

Merged: 3 commits, Jul 6, 2022
90 changes: 41 additions & 49 deletions src/transformers/models/gptj/modeling_tf_gptj.py
@@ -60,14 +60,12 @@
 ]
 
 
-def fixed_pos_embedding(x: tf.Tensor, seq_dim: int = 1, seq_len: Optional[int] = None) -> Tuple[tf.Tensor, tf.Tensor]:
-    dim = shape_list(x)[-1]
-    if seq_len is None:
-        seq_len = shape_list(x)[seq_dim]
+def create_sinusoidal_positions(num_pos: int, dim: int) -> tf.Tensor:
     inv_freq = tf.cast(1.0 / (10000 ** (tf.range(0, dim, 2) / dim)), tf.float32)
-    seq_len_range = tf.cast(tf.range(seq_len), tf.float32)
-    sinusoid_inp = tf.cast(tf.einsum("i , j -> i j", seq_len_range, inv_freq), tf.float32)
-    return tf.cast(tf.sin(sinusoid_inp), dtype=x.dtype), tf.cast(tf.cos(sinusoid_inp), dtype=x.dtype)
+    sinusoid_inp = tf.cast(tf.einsum("i , j -> i j", tf.range(num_pos, dtype=tf.float32), inv_freq), tf.float32)
+    sin, cos = tf.sin(sinusoid_inp), tf.cos(sinusoid_inp)
+    out = tf.concat((sin, cos), axis=1)
+    return out


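The new helper trades the old on-the-fly computation for a table that is built once, with fully static shapes: row i holds sin(i * f) for every rotary frequency f in the first half of the columns and cos(i * f) in the second half. Rotary lookups then become a plain `tf.gather` on `position_ids`, which XLA can compile, instead of depending on a runtime `seq_len` and `offset`. A minimal sketch of how the table is consumed (the sizes are illustrative, not the model's real config):

import tensorflow as tf

table = create_sinusoidal_positions(num_pos=2048, dim=64)  # (2048, 64): sin in columns 0-31, cos in 32-63
position_ids = tf.constant([[5, 6, 7]])                    # (batch, seq)
sincos = tf.gather(table, position_ids, axis=0)            # (1, 3, 64), same rank at every decoding step
sin, cos = tf.split(sincos, 2, axis=-1)                    # (1, 3, 32) each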
@@ -77,11 +75,11 @@ def rotate_every_two(x: tf.Tensor) -> tf.Tensor:
     return rotate_half_tensor
 
 
-def apply_rotary_pos_emb(x: tf.Tensor, sincos: tf.Tensor, offset: int = 0) -> tf.Tensor:
+def apply_rotary_pos_emb(tensor: tf.Tensor, sincos: tf.Tensor) -> tf.Tensor:
     sin_pos, cos_pos = sincos
-    sin_pos = tf.repeat(sin_pos[None, offset : shape_list(x)[1] + offset, None, :], 2, 3)
-    cos_pos = tf.repeat(cos_pos[None, offset : shape_list(x)[1] + offset, None, :], 2, 3)
-    return (x * cos_pos) + (rotate_every_two(x) * sin_pos)
+    sin_pos = tf.repeat(sin_pos[:, :, None, :], 2, 3)
+    cos_pos = tf.repeat(cos_pos[:, :, None, :], 2, 3)
+    return (tensor * cos_pos) + (rotate_every_two(tensor) * sin_pos)


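With the table gathered per position, `apply_rotary_pos_emb` no longer needs an `offset`: absolute positions are already baked into `sincos` by the caller. The `tf.repeat(..., 2, 3)` interleaves each sin/cos value across the (even, odd) feature pairs produced by `rotate_every_two`. One property that makes a quick sanity check, assuming the module's functions above are in scope: rotating by position 0 is the identity, since sin(0) = 0 and cos(0) = 1.

table = create_sinusoidal_positions(num_pos=16, dim=8)
sincos = tf.split(tf.gather(table, tf.constant([[0]]), axis=0), 2, axis=-1)
x = tf.random.normal((1, 1, 2, 8))  # (batch, seq, heads, head_dim), illustrative sizes
tf.debugging.assert_near(apply_rotary_pos_emb(x, sincos), x)  # position 0 leaves the tensor unchanged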
class TFGPTJAttention(tf.keras.layers.Layer):
@@ -132,6 +130,8 @@ def __init__(self, config: GPTJConfig, **kwargs):
             tf.cast(tf.experimental.numpy.tril(tf.ones((self.max_positions, self.max_positions))), tf.int8),
             (1, 1, self.max_positions, self.max_positions),
         )
+        pos_embd_dim = self.rotary_dim or self.embed_dim
+        self.embed_positions = create_sinusoidal_positions(self.max_positions, pos_embd_dim)
 
     def get_causal_mask(self, key_length, query_length) -> tf.Tensor:
         return tf.cast(self.lower_triangle_mask[:, :, key_length - query_length : key_length, :key_length], tf.bool)
@@ -207,8 +207,9 @@ def _attn(
     def call(
         self,
         hidden_states: tf.Tensor,
-        attention_mask: Optional[tf.Tensor] = None,
         layer_past: Optional[Tuple[tf.Tensor, tf.Tensor]] = None,
+        attention_mask: Optional[tf.Tensor] = None,
+        position_ids: Optional[tf.Tensor] = None,
         head_mask: Optional[tf.Tensor] = None,
         use_cache: bool = False,
         output_attentions: bool = False,
@@ -221,30 +222,23 @@ def call(
         key = self._split_heads(key, True)
         value = self._split_heads(value, False)
 
-        seq_len = shape_list(key)[1]
-        offset = 0
-
-        if layer_past is not None:
-            offset = shape_list(layer_past[0])[-2]
-            seq_len += offset
-
+        sincos = tf.gather(self.embed_positions, position_ids, axis=0)
+        sincos = tf.split(sincos, 2, axis=-1)
         if self.rotary_dim is not None:
             k_rot = key[:, :, :, : self.rotary_dim]
             k_pass = key[:, :, :, self.rotary_dim :]
 
             q_rot = query[:, :, :, : self.rotary_dim]
             q_pass = query[:, :, :, self.rotary_dim :]
 
-            sincos = fixed_pos_embedding(k_rot, 1, seq_len=seq_len)
-            k_rot = apply_rotary_pos_emb(k_rot, sincos, offset=offset)
-            q_rot = apply_rotary_pos_emb(q_rot, sincos, offset=offset)
+            k_rot = apply_rotary_pos_emb(k_rot, sincos)
+            q_rot = apply_rotary_pos_emb(q_rot, sincos)
 
             key = tf.concat((k_rot, k_pass), axis=-1)
             query = tf.concat((q_rot, q_pass), axis=-1)
         else:
-            sincos = fixed_pos_embedding(key, 1, seq_len=seq_len)
-            key = apply_rotary_pos_emb(key, sincos, offset=offset)
-            query = apply_rotary_pos_emb(query, sincos, offset=offset)
+            key = apply_rotary_pos_emb(key, sincos)
+            query = apply_rotary_pos_emb(query, sincos)
 
         key = tf.transpose(key, (0, 2, 1, 3))
         query = tf.transpose(query, (0, 2, 1, 3))
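This block is the heart of the XLA fix. The old `seq_len`/`offset` bookkeeping read the runtime shape of `layer_past`, which forces a retrace whenever the cache length changes; gathering rows of the precomputed `embed_positions` by `position_ids` returns the same values with shapes known at trace time. A sketch of what the cached decoding path looks like from the caller's side (table size and step are illustrative):

table = create_sinusoidal_positions(num_pos=2048, dim=64)  # stand-in for self.embed_positions
step = tf.constant([[9]])                                  # batch of 1, currently decoding position 9
sincos = tf.gather(table, step, axis=0)                    # (1, 1, 64) at every step, so no retracing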
@@ -310,16 +304,18 @@ def call(
         hidden_states: tf.Tensor,
         layer_past: Optional[tf.Tensor] = None,
         attention_mask: Optional[tf.Tensor] = None,
+        position_ids: Optional[tf.Tensor] = None,
         head_mask: Optional[tf.Tensor] = None,
         use_cache: bool = False,
         output_attentions: bool = False,
     ):
         residual = hidden_states
         hidden_states = self.ln_1(hidden_states)
         attn_outputs = self.attn(
-            hidden_states,
+            hidden_states=hidden_states,
             layer_past=layer_past,
             attention_mask=attention_mask,
+            position_ids=position_ids,
             head_mask=head_mask,
             use_cache=use_cache,
             output_attentions=output_attentions,
@@ -466,12 +462,13 @@ def call(
                 all_hidden_states = all_hidden_states + (tf.reshape(hidden_states, output_shape),)
 
             outputs = block(
-                hidden_states,
-                layer_past,
-                attention_mask,
-                head_mask[i],
-                use_cache,
-                output_attentions,
+                hidden_states=hidden_states,
+                layer_past=layer_past,
+                attention_mask=attention_mask,
+                position_ids=position_ids,
+                head_mask=head_mask[i],
+                use_cache=use_cache,
+                output_attentions=output_attentions,
                 training=training,
             )

@@ -722,41 +719,36 @@ def __init__(self, config, *inputs, **kwargs):
         self.lm_head = tf.keras.layers.Dense(
             config.vocab_size, kernel_initializer=get_initializer(config.initializer_range), name="lm_head"
         )
-        # TODO (Joao): investigate why GPTJ has numerical issues in XLA generate
-        self.supports_xla_generation = False
 
     def get_output_embeddings(self):
         return self.lm_head
 
     def set_output_embeddings(self, new_embeddings):
         self.lm_head = new_embeddings
 
-    def prepare_inputs_for_generation(self, inputs, past=None, use_cache=None, use_xla=False, **kwargs):
-        # TODO: (Joao) after the TF generator is complete, update GPT2 TF generation to match PT's. NB -- some GPT2
-        # tests will need to be fixed after the change
-
+    def prepare_inputs_for_generation(self, inputs, past=None, use_cache=None, **kwargs):
        token_type_ids = kwargs.get("token_type_ids", None)
         # only last token for inputs_ids if past is defined in kwargs
         if past:
             inputs = tf.expand_dims(inputs[:, -1], -1)
             if token_type_ids is not None:
                 token_type_ids = tf.expand_dims(token_type_ids[:, -1], -1)
 
+        position_ids = kwargs.get("position_ids", None)
+        attention_mask = kwargs.get("attention_mask", None)
 
-        # TODO(pvp, Joao) - this `if use_xla` statement can be removed, but is left
-        # for a future PR to not change too many things for now.
-        # All statements in this if case apply for both xla and non-xla (as they already do in PyTorch)
-        position_ids = None
-        attention_mask = None
-        if use_xla:
-            attention_mask = kwargs.get("attention_mask", None)
-            if past is not None and attention_mask is not None:
-                position_ids = tf.reduce_sum(attention_mask, axis=1, keepdims=True) - 1
-            elif attention_mask is not None:
-                position_ids = tf.math.cumsum(attention_mask, axis=1, exclusive=True)
+        if attention_mask is not None and position_ids is None:
+            position_ids = tf.math.cumsum(attention_mask, axis=-1, exclusive=True)
+            if past:
+                position_ids = tf.expand_dims(position_ids[:, -1], -1)
 
         return {
             "input_ids": inputs,
             "attention_mask": attention_mask,
             "position_ids": position_ids,
             "past": past,
             "use_cache": use_cache,
             "token_type_ids": token_type_ids,
         }
 
     @unpack_inputs
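The rewritten `prepare_inputs_for_generation` now derives `position_ids` from the attention mask on every call rather than only under `use_xla`. With left padding, the exclusive cumulative sum starts counting at the first real token, which is exactly the index the rotary gather expects. A small worked example with a hypothetical left-padded mask:

attention_mask = tf.constant([[0, 0, 1, 1, 1],
                              [1, 1, 1, 1, 1]])
position_ids = tf.math.cumsum(attention_mask, axis=-1, exclusive=True)
# -> [[0, 0, 0, 1, 2],
#     [0, 1, 2, 3, 4]]   (padded slots repeat 0, but they are masked out anyway)
last = tf.expand_dims(position_ids[:, -1], -1)  # with `past`: [[2], [4]], the newest position per row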
137 changes: 44 additions & 93 deletions tests/models/gptj/test_modeling_tf_gptj.py
@@ -13,7 +13,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import datetime
 import unittest
 
 from transformers import AutoTokenizer, GPTJConfig, is_tf_available
@@ -48,6 +47,7 @@ def __init__(self, parent):
         self.use_mc_token_ids = True
         self.vocab_size = 99
         self.hidden_size = 32
+        self.rotary_dim = 4
         self.num_hidden_layers = 5
         self.num_attention_heads = 4
         self.intermediate_size = 37
@@ -103,6 +103,7 @@ def prepare_config_and_inputs(self):
             bos_token_id=self.bos_token_id,
             eos_token_id=self.eos_token_id,
             pad_token_id=self.pad_token_id,
+            rotary_dim=self.rotary_dim,
             return_dict=True,
         )

@@ -359,10 +360,10 @@ def test_resize_token_embeddings(self):
 
 
 @require_tf
+@tooslow
+# Marked as @tooslow due to GPU OOM -- but still useful to run locally. Requires ~39GB of RAM.
 class TFGPTJModelLanguageGenerationTest(unittest.TestCase):
-    @tooslow
     def test_lm_generate_gptj(self):
-        # Marked as @tooslow due to GPU OOM
         model = TFGPTJForCausalLM.from_pretrained("EleutherAI/gpt-j-6B", from_pt=True)
         input_ids = tf.convert_to_tensor([[464, 3290]], dtype=tf.int32)  # The dog
         # fmt: off
@@ -372,74 +373,20 @@ def test_lm_generate_gptj(self):
         output_ids = model.generate(input_ids, do_sample=False)
         self.assertListEqual(output_ids[0].numpy().tolist(), expected_output_ids)
 
-    @tooslow
     def test_gptj_sample(self):
-        # Marked as @tooslow due to GPU OOM (issue #13676)
         tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B", revision="float16")
         model = TFGPTJForCausalLM.from_pretrained("EleutherAI/gpt-j-6B", revision="float16", from_pt=True)
 
-        tf.random.set_seed(0)
-        tokenized = tokenizer("Today is a nice day and", return_tensors="tf", return_token_type_ids=True)
-        input_ids, token_type_ids = tokenized.input_ids, tokenized.token_type_ids
-        output_ids = model.generate(input_ids, do_sample=True)
+        tokenized = tokenizer("Today is a nice day and", return_tensors="tf")
+        # forces the generation to happen on CPU, to avoid GPU-related quirks
+        with tf.device(":/CPU:0"):
+            output_ids = model.generate(**tokenized, do_sample=True, seed=[42, 0])
         output_str = tokenizer.decode(output_ids[0], skip_special_tokens=True)
 
-        output_seq = model.generate(input_ids=input_ids, do_sample=True, num_return_sequences=5)
-        output_seq_tt = model.generate(
-            input_ids=input_ids, token_type_ids=token_type_ids, do_sample=True, num_return_sequences=5
-        )
-        output_seq_strs = tokenizer.batch_decode(output_seq, skip_special_tokens=True)
-        output_seq_tt_strs = tokenizer.batch_decode(output_seq_tt, skip_special_tokens=True)
-
-        EXPECTED_OUTPUT_STR = "Today is a nice day and I am taking an hour to sit in the hammock and just enjoy"
-
+        EXPECTED_OUTPUT_STR = "Today is a nice day and I’m going to go for a walk. I’"
         self.assertEqual(output_str, EXPECTED_OUTPUT_STR)
-        self.assertTrue(
-            all([output_seq_strs[idx] != output_seq_tt_strs[idx] for idx in range(len(output_seq_tt_strs))])
-        )  # token_type_ids should change output
 
-    @slow
-    @unittest.skip(reason="TF generate currently has no time-based stopping criteria")
-    def test_gptj_sample_max_time(self):
-        tokenizer = AutoTokenizer.from_pretrained("anton-l/gpt-j-tiny-random")
-        model = TFGPTJForCausalLM.from_pretrained("anton-l/gpt-j-tiny-random", from_pt=True)
-
-        input_ids = tokenizer("Today is a nice day and", return_tensors="tf", return_token_type_ids=True).input_ids
-
-        MAX_TIME = 0.5
-
-        start = datetime.datetime.now()
-        model.generate(input_ids, do_sample=True, max_time=MAX_TIME, max_length=256)
-        duration = datetime.datetime.now() - start
-        self.assertGreater(duration, datetime.timedelta(seconds=MAX_TIME))
-        self.assertLess(duration, datetime.timedelta(seconds=1.5 * MAX_TIME))
-
-        start = datetime.datetime.now()
-        model.generate(input_ids, do_sample=False, max_time=MAX_TIME, max_length=256)
-        duration = datetime.datetime.now() - start
-        self.assertGreater(duration, datetime.timedelta(seconds=MAX_TIME))
-        self.assertLess(duration, datetime.timedelta(seconds=1.5 * MAX_TIME))
-
-        start = datetime.datetime.now()
-        model.generate(input_ids, do_sample=False, num_beams=2, max_time=MAX_TIME, max_length=256)
-        duration = datetime.datetime.now() - start
-        self.assertGreater(duration, datetime.timedelta(seconds=MAX_TIME))
-        self.assertLess(duration, datetime.timedelta(seconds=1.5 * MAX_TIME))
-
-        start = datetime.datetime.now()
-        model.generate(input_ids, do_sample=True, num_beams=2, max_time=MAX_TIME, max_length=256)
-        duration = datetime.datetime.now() - start
-        self.assertGreater(duration, datetime.timedelta(seconds=MAX_TIME))
-        self.assertLess(duration, datetime.timedelta(seconds=1.5 * MAX_TIME))
-
-        start = datetime.datetime.now()
-        model.generate(input_ids, do_sample=False, max_time=None, max_length=256)
-        duration = datetime.datetime.now() - start
-        self.assertGreater(duration, datetime.timedelta(seconds=1.5 * MAX_TIME))
-
-    @tooslow
-    def test_batch_generation(self):
-        # Marked as @tooslow due to GPU OOM
+    def _get_beam_search_test_objects(self):
         model = TFGPTJForCausalLM.from_pretrained("EleutherAI/gpt-j-6B", revision="float16", from_pt=True)
         tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B", revision="float16")
 
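Note the sampling test's switch from `tf.random.set_seed` to `generate`'s stateless `seed` argument: a stateless seed keeps sampling reproducible even inside compiled functions, where mutating global RNG state is unreliable. A sketch of the property, reusing the test's `model` and `tokenized` (the seed value itself is arbitrary):

out_a = model.generate(**tokenized, do_sample=True, seed=[42, 0])
out_b = model.generate(**tokenized, do_sample=True, seed=[42, 0])
# out_a and out_b should match token for token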
@@ -454,42 +401,46 @@ def test_batch_generation(self):
             "Hello, my dog is a little",
             "Today, I",
         ]
+        expected_output_sentences = [
+            "Hello, my dog is a little over a year old and has been diagnosed with hip dysplasia",
+            "Today, I’m going to be talking about a topic that’",
+        ]
+        return model, tokenizer, sentences, expected_output_sentences
 
-        inputs = tokenizer(sentences, return_tensors="tf", padding=True)
-        input_ids = inputs["input_ids"]
-        token_type_ids = tf.concat(
-            [
-                tf.zeros((input_ids.shape[0], input_ids.shape[1] - 1), dtype=tf.int64),
-                500 * tf.ones((input_ids.shape[0], 1), dtype=tf.int64),
-            ],
-            axis=-1,
-        )
+    def test_batch_beam_search(self):
+        # Confirms that we get the expected results with left-padded beam search
+        model, tokenizer, sentences, expected_output_sentences = self._get_beam_search_test_objects()
 
-        outputs = model.generate(input_ids=input_ids, attention_mask=inputs["attention_mask"])
-        outputs_tt = model.generate(
-            input_ids=input_ids,
-            attention_mask=inputs["attention_mask"],
-            token_type_ids=token_type_ids,
-        )
+        inputs = tokenizer(sentences, return_tensors="tf", padding=True)
+        outputs = model.generate(**inputs, do_sample=False, num_beams=2)
+        batch_out_sentence = tokenizer.batch_decode(outputs, skip_special_tokens=True)
+        self.assertListEqual(expected_output_sentences, batch_out_sentence)
 
-        inputs_non_padded = tokenizer(sentences[0], return_tensors="tf").input_ids
-        output_non_padded = model.generate(input_ids=inputs_non_padded)
+    def test_batch_left_padding(self):
+        # Confirms that left-padding is working properly
+        model, tokenizer, sentences, expected_output_sentences = self._get_beam_search_test_objects()
+
+        inputs = tokenizer(sentences, return_tensors="tf", padding=True)
+        inputs_non_padded = tokenizer(sentences[0], return_tensors="tf")
+        output_non_padded = model.generate(**inputs_non_padded, do_sample=False, num_beams=2)
         num_paddings = (
-            shape_list(inputs_non_padded)[-1] - tf.reduce_sum(tf.cast(inputs["attention_mask"][-1], tf.int64)).numpy()
+            shape_list(inputs_non_padded["input_ids"])[-1]
+            - tf.reduce_sum(tf.cast(inputs["attention_mask"][-1], tf.int64)).numpy()
         )
+        inputs_padded = tokenizer(sentences[1], return_tensors="tf")
+        output_padded = model.generate(
+            **inputs_padded, do_sample=False, num_beams=2, max_length=model.config.max_length - num_paddings
+        )
-        inputs_padded = tokenizer(sentences[1], return_tensors="tf").input_ids
-        output_padded = model.generate(input_ids=inputs_padded, max_length=model.config.max_length - num_paddings)
 
-        batch_out_sentence = tokenizer.batch_decode(outputs, skip_special_tokens=True)
-        batch_out_sentence_tt = tokenizer.batch_decode(outputs_tt, skip_special_tokens=True)
         non_padded_sentence = tokenizer.decode(output_non_padded[0], skip_special_tokens=True)
         padded_sentence = tokenizer.decode(output_padded[0], skip_special_tokens=True)
+        self.assertListEqual(expected_output_sentences, [non_padded_sentence, padded_sentence])
 
-        expected_output_sentence = [
-            "Hello, my dog is a little over a year old and has been diagnosed with a heart murmur",
-            "Today, I’m going to share with you a few of my favorite",
-        ]
-        self.assertListEqual(expected_output_sentence, batch_out_sentence)
-        self.assertTrue(batch_out_sentence_tt != batch_out_sentence)  # token_type_ids should change output
-        self.assertListEqual(expected_output_sentence, [non_padded_sentence, padded_sentence])
+
+    def test_xla_beam_search(self):
+        # Confirms that XLA is working properly
+        model, tokenizer, sentences, expected_output_sentences = self._get_beam_search_test_objects()
+
+        inputs = tokenizer(sentences, return_tensors="tf", padding=True)
+        xla_generate = tf.function(model.generate, jit_compile=True)
+        outputs_xla = xla_generate(**inputs, do_sample=False, num_beams=2)
+        xla_sentence = tokenizer.batch_decode(outputs_xla, skip_special_tokens=True)
+        self.assertListEqual(expected_output_sentences, xla_sentence)
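For anyone reproducing this outside the test suite, the same recipe applies: compile `generate` once with `jit_compile=True` and keep input shapes fixed so XLA does not retrace. A hedged usage sketch; the left padding and the fixed `max_length` padding are the load-bearing parts, while the checkpoint and lengths are illustrative:

import tensorflow as tf
from transformers import AutoTokenizer, TFGPTJForCausalLM

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B", revision="float16")
model = TFGPTJForCausalLM.from_pretrained("EleutherAI/gpt-j-6B", revision="float16", from_pt=True)
tokenizer.padding_side = "left"
tokenizer.pad_token = tokenizer.eos_token
model.config.pad_token_id = model.config.eos_token_id

xla_generate = tf.function(model.generate, jit_compile=True)
inputs = tokenizer(["The dog", "Today, I"], return_tensors="tf", padding="max_length", max_length=32)
outputs = xla_generate(**inputs, do_sample=False, num_beams=2)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))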