From 8ca86ba8bb69a15ec57e43ae5d926025ff801b1c Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Fri, 1 Jul 2022 16:42:01 +0000 Subject: [PATCH 1/3] tmp commit --- .../models/gptj/modeling_tf_gptj.py | 66 +++++---- tests/models/gptj/test_modeling_tf_gptj.py | 132 ++++++------------ 2 files changed, 79 insertions(+), 119 deletions(-) diff --git a/src/transformers/models/gptj/modeling_tf_gptj.py b/src/transformers/models/gptj/modeling_tf_gptj.py index efce9e7086bfe..94c5861efefd6 100644 --- a/src/transformers/models/gptj/modeling_tf_gptj.py +++ b/src/transformers/models/gptj/modeling_tf_gptj.py @@ -70,6 +70,14 @@ def fixed_pos_embedding(x: tf.Tensor, seq_dim: int = 1, seq_len: Optional[int] = return tf.cast(tf.sin(sinusoid_inp), dtype=x.dtype), tf.cast(tf.cos(sinusoid_inp), dtype=x.dtype) +def create_sinusoidal_positions(num_pos: int, dim: int) -> tf.Tensor: + inv_freq = tf.cast(1.0 / (10000 ** (tf.range(0, dim, 2) / dim)), tf.float32) + sinusoid_inp = tf.cast(tf.einsum("i , j -> i j", tf.range(num_pos, dtype=tf.float32), inv_freq), tf.float32) + sin, cos = tf.sin(sinusoid_inp), tf.cos(sinusoid_inp) + out = tf.concat((sin, cos), axis=1) + return out + + def rotate_every_two(x: tf.Tensor) -> tf.Tensor: rotate_half_tensor = tf.stack((-x[:, :, :, 1::2], x[:, :, :, ::2]), axis=-1) new_shape = shape_list(rotate_half_tensor)[:-2] + [tf.math.reduce_prod(shape_list(rotate_half_tensor)[-2:])] @@ -132,6 +140,8 @@ def __init__(self, config: GPTJConfig, **kwargs): tf.cast(tf.experimental.numpy.tril(tf.ones((self.max_positions, self.max_positions))), tf.int8), (1, 1, self.max_positions, self.max_positions), ) + pos_embd_dim = self.rotary_dim or self.embed_dim + self.embed_positions = create_sinusoidal_positions(self.max_positions, pos_embd_dim) def get_causal_mask(self, key_length, query_length) -> tf.Tensor: return tf.cast(self.lower_triangle_mask[:, :, key_length - query_length : key_length, :key_length], tf.bool) @@ -207,8 +217,9 @@ def _attn( def call( self, hidden_states: tf.Tensor, - attention_mask: Optional[tf.Tensor] = None, layer_past: Optional[Tuple[tf.Tensor, tf.Tensor]] = None, + attention_mask: Optional[tf.Tensor] = None, + position_ids: Optional[tf.Tensor] = None, head_mask: Optional[tf.Tensor] = None, use_cache: bool = False, output_attentions: bool = False, @@ -228,6 +239,9 @@ def call( offset = shape_list(layer_past[0])[-2] seq_len += offset + breakpoint() + sincos = tf.gather(self.embed_positions, position_ids, axis=0) + sincos = tf.split(sincos, 2, axis=-1) if self.rotary_dim is not None: k_rot = key[:, :, :, : self.rotary_dim] k_pass = key[:, :, :, self.rotary_dim :] @@ -235,14 +249,14 @@ def call( q_rot = query[:, :, :, : self.rotary_dim] q_pass = query[:, :, :, self.rotary_dim :] - sincos = fixed_pos_embedding(k_rot, 1, seq_len=seq_len) + # sincos = fixed_pos_embedding(k_rot, 1, seq_len=seq_len) k_rot = apply_rotary_pos_emb(k_rot, sincos, offset=offset) q_rot = apply_rotary_pos_emb(q_rot, sincos, offset=offset) key = tf.concat((k_rot, k_pass), axis=-1) query = tf.concat((q_rot, q_pass), axis=-1) else: - sincos = fixed_pos_embedding(key, 1, seq_len=seq_len) + # sincos = fixed_pos_embedding(key, 1, seq_len=seq_len) key = apply_rotary_pos_emb(key, sincos, offset=offset) query = apply_rotary_pos_emb(query, sincos, offset=offset) @@ -310,6 +324,7 @@ def call( hidden_states: tf.Tensor, layer_past: Optional[tf.Tensor] = None, attention_mask: Optional[tf.Tensor] = None, + position_ids: Optional[tf.Tensor] = None, head_mask: Optional[tf.Tensor] = None, use_cache: bool = False, 
output_attentions: bool = False, @@ -317,9 +332,10 @@ def call( residual = hidden_states hidden_states = self.ln_1(hidden_states) attn_outputs = self.attn( - hidden_states, + hidden_states=hidden_states, layer_past=layer_past, attention_mask=attention_mask, + position_ids=position_ids, head_mask=head_mask, use_cache=use_cache, output_attentions=output_attentions, @@ -466,12 +482,13 @@ def call( all_hidden_states = all_hidden_states + (tf.reshape(hidden_states, output_shape),) outputs = block( - hidden_states, - layer_past, - attention_mask, - head_mask[i], - use_cache, - output_attentions, + hidden_states=hidden_states, + layer_past=layer_past, + attention_mask=attention_mask, + position_ids=position_ids, + head_mask=head_mask[i], + use_cache=use_cache, + output_attentions=output_attentions, training=training, ) @@ -722,8 +739,6 @@ def __init__(self, config, *inputs, **kwargs): self.lm_head = tf.keras.layers.Dense( config.vocab_size, kernel_initializer=get_initializer(config.initializer_range), name="lm_head" ) - # TODO (Joao): investigate why GPTJ has numerical issues in XLA generate - self.supports_xla_generation = False def get_output_embeddings(self): return self.lm_head @@ -731,25 +746,21 @@ def get_output_embeddings(self): def set_output_embeddings(self, new_embeddings): self.lm_head = new_embeddings - def prepare_inputs_for_generation(self, inputs, past=None, use_cache=None, use_xla=False, **kwargs): - # TODO: (Joao) after the TF generator is complete, update GPT2 TF generation to match PT's. NB -- some GPT2 - # tests will need to be fixed after the change - + def prepare_inputs_for_generation(self, inputs, past=None, use_cache=None, **kwargs): + token_type_ids = kwargs.get("token_type_ids", None) # only last token for inputs_ids if past is defined in kwargs if past: inputs = tf.expand_dims(inputs[:, -1], -1) + if token_type_ids is not None: + token_type_ids = tf.expand_dims(token_type_ids[:, -1], -1) + + position_ids = kwargs.get("position_ids", None) + attention_mask = kwargs.get("attention_mask", None) - # TODO(pvp, Joao) - this `if use_xla` statement can be removed, but is left - # for a future PR to not change too many things for now. - # All statements in this if case apply for both xla and non-xla (as they already do in PyTorch) - position_ids = None - attention_mask = None - if use_xla: - attention_mask = kwargs.get("attention_mask", None) - if past is not None and attention_mask is not None: - position_ids = tf.reduce_sum(attention_mask, axis=1, keepdims=True) - 1 - elif attention_mask is not None: - position_ids = tf.math.cumsum(attention_mask, axis=1, exclusive=True) + if attention_mask is not None and position_ids is None: + position_ids = tf.math.cumsum(attention_mask, axis=-1, exclusive=True) + if past: + position_ids = tf.expand_dims(position_ids[:, -1], -1) return { "input_ids": inputs, @@ -757,6 +768,7 @@ def prepare_inputs_for_generation(self, inputs, past=None, use_cache=None, use_x "position_ids": position_ids, "past": past, "use_cache": use_cache, + "token_type_ids": token_type_ids, } @unpack_inputs diff --git a/tests/models/gptj/test_modeling_tf_gptj.py b/tests/models/gptj/test_modeling_tf_gptj.py index 0d9af0b65087a..7ae7de3b672ec 100644 --- a/tests/models/gptj/test_modeling_tf_gptj.py +++ b/tests/models/gptj/test_modeling_tf_gptj.py @@ -359,10 +359,11 @@ def test_resize_token_embeddings(self): @require_tf +# @tooslow +# Marked as @tooslow due to GPU OOM -- but still useful to run locally. Requires ~39GB of RAM. 
class TFGPTJModelLanguageGenerationTest(unittest.TestCase): - @tooslow + def test_lm_generate_gptj(self): - # Marked as @tooslow due to GPU OOM model = TFGPTJForCausalLM.from_pretrained("EleutherAI/gpt-j-6B", from_pt=True) input_ids = tf.convert_to_tensor([[464, 3290]], dtype=tf.int32) # The dog # fmt: off @@ -372,74 +373,20 @@ def test_lm_generate_gptj(self): output_ids = model.generate(input_ids, do_sample=False) self.assertListEqual(output_ids[0].numpy().tolist(), expected_output_ids) - @tooslow def test_gptj_sample(self): - # Marked as @tooslow due to GPU OOM (issue #13676) tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B", revision="float16") model = TFGPTJForCausalLM.from_pretrained("EleutherAI/gpt-j-6B", revision="float16", from_pt=True) - tf.random.set_seed(0) - tokenized = tokenizer("Today is a nice day and", return_tensors="tf", return_token_type_ids=True) - input_ids, token_type_ids = tokenized.input_ids, tokenized.token_type_ids - output_ids = model.generate(input_ids, do_sample=True) + tokenized = tokenizer("Today is a nice day and", return_tensors="tf") + # forces the generation to happen on CPU, to avoid GPU-related quirks + with tf.device(":/CPU:0"): + output_ids = model.generate(**tokenized, do_sample=True, seed=[42, 0]) output_str = tokenizer.decode(output_ids[0], skip_special_tokens=True) - output_seq = model.generate(input_ids=input_ids, do_sample=True, num_return_sequences=5) - output_seq_tt = model.generate( - input_ids=input_ids, token_type_ids=token_type_ids, do_sample=True, num_return_sequences=5 - ) - output_seq_strs = tokenizer.batch_decode(output_seq, skip_special_tokens=True) - output_seq_tt_strs = tokenizer.batch_decode(output_seq_tt, skip_special_tokens=True) - - EXPECTED_OUTPUT_STR = "Today is a nice day and I am taking an hour to sit in the hammock and just enjoy" - + EXPECTED_OUTPUT_STR = "Today is a nice day and I’m going to go for a walk. 
I’" self.assertEqual(output_str, EXPECTED_OUTPUT_STR) - self.assertTrue( - all([output_seq_strs[idx] != output_seq_tt_strs[idx] for idx in range(len(output_seq_tt_strs))]) - ) # token_type_ids should change output - @slow - @unittest.skip(reason="TF generate currently has no time-based stopping criteria") - def test_gptj_sample_max_time(self): - tokenizer = AutoTokenizer.from_pretrained("anton-l/gpt-j-tiny-random") - model = TFGPTJForCausalLM.from_pretrained("anton-l/gpt-j-tiny-random", from_pt=True) - - input_ids = tokenizer("Today is a nice day and", return_tensors="tf", return_token_type_ids=True).input_ids - - MAX_TIME = 0.5 - - start = datetime.datetime.now() - model.generate(input_ids, do_sample=True, max_time=MAX_TIME, max_length=256) - duration = datetime.datetime.now() - start - self.assertGreater(duration, datetime.timedelta(seconds=MAX_TIME)) - self.assertLess(duration, datetime.timedelta(seconds=1.5 * MAX_TIME)) - - start = datetime.datetime.now() - model.generate(input_ids, do_sample=False, max_time=MAX_TIME, max_length=256) - duration = datetime.datetime.now() - start - self.assertGreater(duration, datetime.timedelta(seconds=MAX_TIME)) - self.assertLess(duration, datetime.timedelta(seconds=1.5 * MAX_TIME)) - - start = datetime.datetime.now() - model.generate(input_ids, do_sample=False, num_beams=2, max_time=MAX_TIME, max_length=256) - duration = datetime.datetime.now() - start - self.assertGreater(duration, datetime.timedelta(seconds=MAX_TIME)) - self.assertLess(duration, datetime.timedelta(seconds=1.5 * MAX_TIME)) - - start = datetime.datetime.now() - model.generate(input_ids, do_sample=True, num_beams=2, max_time=MAX_TIME, max_length=256) - duration = datetime.datetime.now() - start - self.assertGreater(duration, datetime.timedelta(seconds=MAX_TIME)) - self.assertLess(duration, datetime.timedelta(seconds=1.5 * MAX_TIME)) - - start = datetime.datetime.now() - model.generate(input_ids, do_sample=False, max_time=None, max_length=256) - duration = datetime.datetime.now() - start - self.assertGreater(duration, datetime.timedelta(seconds=1.5 * MAX_TIME)) - - @tooslow - def test_batch_generation(self): - # Marked as @tooslow due to GPU OOM + def _get_beam_search_test_objects(self): model = TFGPTJForCausalLM.from_pretrained("EleutherAI/gpt-j-6B", revision="float16", from_pt=True) tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B", revision="float16") @@ -454,42 +401,43 @@ def test_batch_generation(self): "Hello, my dog is a little", "Today, I", ] + expected_output_sentences = [ + "Hello, my dog is a little over a year old and has been diagnosed with hip dysplasia", + "Today, I’m going to be talking about a topic that’", + ] + return model, tokenizer, sentences, expected_output_sentences - inputs = tokenizer(sentences, return_tensors="tf", padding=True) - input_ids = inputs["input_ids"] - token_type_ids = tf.concat( - [ - tf.zeros((input_ids.shape[0], input_ids.shape[1] - 1), dtype=tf.int64), - 500 * tf.ones((input_ids.shape[0], 1), dtype=tf.int64), - ], - axis=-1, - ) + def test_batch_beam_search(self): + # Confirms that we get the expected results with left-padded beam search + model, tokenizer, sentences, expected_output_sentences = self._get_beam_search_test_objects() - outputs = model.generate(input_ids=input_ids, attention_mask=inputs["attention_mask"]) - outputs_tt = model.generate( - input_ids=input_ids, - attention_mask=inputs["attention_mask"], - token_type_ids=token_type_ids, - ) + inputs = tokenizer(sentences, return_tensors="tf", padding=True) + outputs = 
model.generate(**inputs, do_sample=False, num_beams=2) + batch_out_sentence = tokenizer.batch_decode(outputs, skip_special_tokens=True) + self.assertListEqual(expected_output_sentences, batch_out_sentence) - inputs_non_padded = tokenizer(sentences[0], return_tensors="tf").input_ids - output_non_padded = model.generate(input_ids=inputs_non_padded) + def test_batch_left_padding(self): + # Confirms that left-padding is working properly + model, tokenizer, sentences, expected_output_sentences = self._get_beam_search_test_objects() + inputs = tokenizer(sentences, return_tensors="tf", padding=True) + inputs_non_padded = tokenizer(sentences[0], return_tensors="tf") + output_non_padded = model.generate(**inputs_non_padded, do_sample=False, num_beams=2) num_paddings = ( - shape_list(inputs_non_padded)[-1] - tf.reduce_sum(tf.cast(inputs["attention_mask"][-1], tf.int64)).numpy() + shape_list(inputs_non_padded["input_ids"])[-1] - tf.reduce_sum(tf.cast(inputs["attention_mask"][-1], tf.int64)).numpy() ) - inputs_padded = tokenizer(sentences[1], return_tensors="tf").input_ids - output_padded = model.generate(input_ids=inputs_padded, max_length=model.config.max_length - num_paddings) - - batch_out_sentence = tokenizer.batch_decode(outputs, skip_special_tokens=True) - batch_out_sentence_tt = tokenizer.batch_decode(outputs_tt, skip_special_tokens=True) + inputs_padded = tokenizer(sentences[1], return_tensors="tf") + output_padded = model.generate(**inputs_padded, do_sample=False, num_beams=2, max_length=model.config.max_length - num_paddings) non_padded_sentence = tokenizer.decode(output_non_padded[0], skip_special_tokens=True) padded_sentence = tokenizer.decode(output_padded[0], skip_special_tokens=True) + self.assertListEqual(expected_output_sentences, [non_padded_sentence, padded_sentence]) - expected_output_sentence = [ - "Hello, my dog is a little over a year old and has been diagnosed with a heart murmur", - "Today, I’m going to share with you a few of my favorite", - ] - self.assertListEqual(expected_output_sentence, batch_out_sentence) - self.assertTrue(batch_out_sentence_tt != batch_out_sentence) # token_type_ids should change output - self.assertListEqual(expected_output_sentence, [non_padded_sentence, padded_sentence]) + def test_xla_beam_search(self): + # Confirms that XLA is working properly + model, tokenizer, sentences, expected_output_sentences = self._get_beam_search_test_objects() + + inputs = tokenizer(sentences, return_tensors="tf", padding=True) + xla_generate = tf.function(model.generate, jit_compile=True) + outputs_xla = xla_generate(**inputs, do_sample=False, num_beams=2) + xla_sentence = tokenizer.batch_decode(outputs_xla, skip_special_tokens=True) + self.assertListEqual(expected_output_sentences, xla_sentence) From 9c99ec95eb90d2ec82aff78fb36104c98595b4fd Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Fri, 1 Jul 2022 17:00:47 +0000 Subject: [PATCH 2/3] XLA GPT-J --- .../models/gptj/modeling_tf_gptj.py | 36 +++++-------------- tests/models/gptj/test_modeling_tf_gptj.py | 11 +++--- 2 files changed, 14 insertions(+), 33 deletions(-) diff --git a/src/transformers/models/gptj/modeling_tf_gptj.py b/src/transformers/models/gptj/modeling_tf_gptj.py index 94c5861efefd6..ae0c83fae9b89 100644 --- a/src/transformers/models/gptj/modeling_tf_gptj.py +++ b/src/transformers/models/gptj/modeling_tf_gptj.py @@ -60,16 +60,6 @@ ] -def fixed_pos_embedding(x: tf.Tensor, seq_dim: int = 1, seq_len: Optional[int] = None) -> Tuple[tf.Tensor, tf.Tensor]: - dim = shape_list(x)[-1] - if seq_len is None: - 
seq_len = shape_list(x)[seq_dim] - inv_freq = tf.cast(1.0 / (10000 ** (tf.range(0, dim, 2) / dim)), tf.float32) - seq_len_range = tf.cast(tf.range(seq_len), tf.float32) - sinusoid_inp = tf.cast(tf.einsum("i , j -> i j", seq_len_range, inv_freq), tf.float32) - return tf.cast(tf.sin(sinusoid_inp), dtype=x.dtype), tf.cast(tf.cos(sinusoid_inp), dtype=x.dtype) - - def create_sinusoidal_positions(num_pos: int, dim: int) -> tf.Tensor: inv_freq = tf.cast(1.0 / (10000 ** (tf.range(0, dim, 2) / dim)), tf.float32) sinusoid_inp = tf.cast(tf.einsum("i , j -> i j", tf.range(num_pos, dtype=tf.float32), inv_freq), tf.float32) @@ -85,11 +75,11 @@ def rotate_every_two(x: tf.Tensor) -> tf.Tensor: return rotate_half_tensor -def apply_rotary_pos_emb(x: tf.Tensor, sincos: tf.Tensor, offset: int = 0) -> tf.Tensor: +def apply_rotary_pos_emb(tensor: tf.Tensor, sincos: tf.Tensor) -> tf.Tensor: sin_pos, cos_pos = sincos - sin_pos = tf.repeat(sin_pos[None, offset : shape_list(x)[1] + offset, None, :], 2, 3) - cos_pos = tf.repeat(cos_pos[None, offset : shape_list(x)[1] + offset, None, :], 2, 3) - return (x * cos_pos) + (rotate_every_two(x) * sin_pos) + sin_pos = tf.repeat(sin_pos[:, :, None, :], 2, 3) + cos_pos = tf.repeat(cos_pos[:, :, None, :], 2, 3) + return (tensor * cos_pos) + (rotate_every_two(tensor) * sin_pos) class TFGPTJAttention(tf.keras.layers.Layer): @@ -232,14 +222,6 @@ def call( key = self._split_heads(key, True) value = self._split_heads(value, False) - seq_len = shape_list(key)[1] - offset = 0 - - if layer_past is not None: - offset = shape_list(layer_past[0])[-2] - seq_len += offset - - breakpoint() sincos = tf.gather(self.embed_positions, position_ids, axis=0) sincos = tf.split(sincos, 2, axis=-1) if self.rotary_dim is not None: @@ -249,16 +231,14 @@ def call( q_rot = query[:, :, :, : self.rotary_dim] q_pass = query[:, :, :, self.rotary_dim :] - # sincos = fixed_pos_embedding(k_rot, 1, seq_len=seq_len) - k_rot = apply_rotary_pos_emb(k_rot, sincos, offset=offset) - q_rot = apply_rotary_pos_emb(q_rot, sincos, offset=offset) + k_rot = apply_rotary_pos_emb(k_rot, sincos) + q_rot = apply_rotary_pos_emb(q_rot, sincos) key = tf.concat((k_rot, k_pass), axis=-1) query = tf.concat((q_rot, q_pass), axis=-1) else: - # sincos = fixed_pos_embedding(key, 1, seq_len=seq_len) - key = apply_rotary_pos_emb(key, sincos, offset=offset) - query = apply_rotary_pos_emb(query, sincos, offset=offset) + key = apply_rotary_pos_emb(key, sincos) + query = apply_rotary_pos_emb(query, sincos) key = tf.transpose(key, (0, 2, 1, 3)) query = tf.transpose(query, (0, 2, 1, 3)) diff --git a/tests/models/gptj/test_modeling_tf_gptj.py b/tests/models/gptj/test_modeling_tf_gptj.py index 7ae7de3b672ec..07ad27d570253 100644 --- a/tests/models/gptj/test_modeling_tf_gptj.py +++ b/tests/models/gptj/test_modeling_tf_gptj.py @@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import datetime import unittest from transformers import AutoTokenizer, GPTJConfig, is_tf_available @@ -359,10 +358,9 @@ def test_resize_token_embeddings(self): @require_tf -# @tooslow +@tooslow # Marked as @tooslow due to GPU OOM -- but still useful to run locally. Requires ~39GB of RAM. 
class TFGPTJModelLanguageGenerationTest(unittest.TestCase): - def test_lm_generate_gptj(self): model = TFGPTJForCausalLM.from_pretrained("EleutherAI/gpt-j-6B", from_pt=True) input_ids = tf.convert_to_tensor([[464, 3290]], dtype=tf.int32) # The dog @@ -424,10 +422,13 @@ def test_batch_left_padding(self): inputs_non_padded = tokenizer(sentences[0], return_tensors="tf") output_non_padded = model.generate(**inputs_non_padded, do_sample=False, num_beams=2) num_paddings = ( - shape_list(inputs_non_padded["input_ids"])[-1] - tf.reduce_sum(tf.cast(inputs["attention_mask"][-1], tf.int64)).numpy() + shape_list(inputs_non_padded["input_ids"])[-1] + - tf.reduce_sum(tf.cast(inputs["attention_mask"][-1], tf.int64)).numpy() ) inputs_padded = tokenizer(sentences[1], return_tensors="tf") - output_padded = model.generate(**inputs_padded, do_sample=False, num_beams=2, max_length=model.config.max_length - num_paddings) + output_padded = model.generate( + **inputs_padded, do_sample=False, num_beams=2, max_length=model.config.max_length - num_paddings + ) non_padded_sentence = tokenizer.decode(output_non_padded[0], skip_special_tokens=True) padded_sentence = tokenizer.decode(output_padded[0], skip_special_tokens=True) self.assertListEqual(expected_output_sentences, [non_padded_sentence, padded_sentence]) From bbd7c4b3334b599e9600197532015b9df1d1fe1a Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Fri, 1 Jul 2022 17:09:26 +0000 Subject: [PATCH 3/3] add missing test config option --- tests/models/gptj/test_modeling_tf_gptj.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/models/gptj/test_modeling_tf_gptj.py b/tests/models/gptj/test_modeling_tf_gptj.py index 07ad27d570253..ec6c15d3f6d64 100644 --- a/tests/models/gptj/test_modeling_tf_gptj.py +++ b/tests/models/gptj/test_modeling_tf_gptj.py @@ -47,6 +47,7 @@ def __init__(self, parent): self.use_mc_token_ids = True self.vocab_size = 99 self.hidden_size = 32 + self.rotary_dim = 4 self.num_hidden_layers = 5 self.num_attention_heads = 4 self.intermediate_size = 37 @@ -102,6 +103,7 @@ def prepare_config_and_inputs(self): bos_token_id=self.bos_token_id, eos_token_id=self.eos_token_id, pad_token_id=self.pad_token_id, + rotary_dim=self.rotary_dim, return_dict=True, )
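
Note for reviewers (not part of the patches themselves): the core modeling change swaps the on-the-fly fixed_pos_embedding, which depended on the running seq_len/offset, for a sin/cos table built once in __init__ by create_sinusoidal_positions and indexed with position_ids through tf.gather. Because the lookup is a plain gather over a fixed-shape table, it traces cleanly under XLA. The snippet below is a minimal, self-contained sketch of that path outside the model, assuming only TensorFlow; the helper names mirror the patch, while the toy shapes and position values are made up for illustration.

import tensorflow as tf

def create_sinusoidal_positions(num_pos, dim):
    # Same formula as the patch: one row per absolute position, with the sin and cos
    # halves concatenated on the last axis.
    inv_freq = 1.0 / (10000 ** (tf.range(0, dim, 2, dtype=tf.float32) / dim))
    sinusoid_inp = tf.einsum("i,j->ij", tf.range(num_pos, dtype=tf.float32), inv_freq)
    return tf.concat((tf.sin(sinusoid_inp), tf.cos(sinusoid_inp)), axis=1)

def rotate_every_two(x):
    # Pairs adjacent channels (x0, x1) -> (-x1, x0), as in the modeling file.
    rotated = tf.stack((-x[:, :, :, 1::2], x[:, :, :, ::2]), axis=-1)
    return tf.reshape(rotated, tf.concat([tf.shape(rotated)[:-2], [-1]], axis=0))

def apply_rotary_pos_emb(tensor, sincos):
    # New signature after the patch: no `offset` argument, because the gather by
    # position_ids has already selected the correct rows of the table.
    sin_pos, cos_pos = sincos
    sin_pos = tf.repeat(sin_pos[:, :, None, :], 2, axis=3)
    cos_pos = tf.repeat(cos_pos[:, :, None, :], 2, axis=3)
    return (tensor * cos_pos) + (rotate_every_two(tensor) * sin_pos)

# Toy usage (illustrative numbers): decoding 3 tokens whose absolute positions are 5..7,
# e.g. because 5 tokens already sit in the cache.
rotary_dim, max_positions = 8, 32
table = create_sinusoidal_positions(max_positions, rotary_dim)          # (32, 8), built once
position_ids = tf.constant([[5, 6, 7]])
sincos = tf.split(tf.gather(table, position_ids, axis=0), 2, axis=-1)   # two (1, 3, 4) tensors
query = tf.random.normal((1, 3, 2, rotary_dim))                         # (batch, seq, heads, rotary_dim)
print(apply_rotary_pos_emb(query, sincos).shape)                        # (1, 3, 2, 8)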
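
On the generation side, prepare_inputs_for_generation now always derives position_ids from the attention mask (previously this happened only in the use_xla branch), so left-padded batches, as exercised by test_batch_left_padding and test_xla_beam_search, get correct absolute positions. A standalone sketch of that arithmetic, with an illustrative mask:

import tensorflow as tf

# Two rows: the first is left-padded with two pad tokens, the second is full length.
attention_mask = tf.constant([[0, 0, 1, 1, 1],
                              [1, 1, 1, 1, 1]])

# Exclusive cumulative sum: each real token gets its 0-based position within the
# unpadded sequence; left-pad positions come out as 0, which is harmless because
# they are masked out of the attention anyway.
position_ids = tf.math.cumsum(attention_mask, axis=-1, exclusive=True)
print(position_ids.numpy())   # [[0 0 0 1 2]
                              #  [0 1 2 3 4]]

# With a populated cache (`past`), only the last position is fed to the model,
# mirroring `tf.expand_dims(position_ids[:, -1], -1)` in the patch.
print(tf.expand_dims(position_ids[:, -1], -1).numpy())   # [[2] [4]]

Feeding these per-token positions into the precomputed table gather shown above is what removes the need for the old seq_len/offset bookkeeping once a cache is present.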