diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index e79009b9b791f4..7484f80e700f2b 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -2110,6 +2110,7 @@ [ "FlaxXGLMModel", "FlaxXGLMPreTrainedModel", + "FlaxXGLMForCausalLM", ] ) else: @@ -3847,7 +3848,7 @@ FlaxWav2Vec2Model, FlaxWav2Vec2PreTrainedModel, ) - from .models.xglm import FlaxXGLMModel, FlaxXGLMPreTrainedModel + from .models.xglm import FlaxXGLMForCausalLM, FlaxXGLMModel, FlaxXGLMPreTrainedModel else: # Import the same objects as dummies to get them in the namespace. # They will raise an import error if the user tries to instantiate / use them. diff --git a/src/transformers/models/xglm/__init__.py b/src/transformers/models/xglm/__init__.py index b23a654843b7e9..581180a94c4876 100644 --- a/src/transformers/models/xglm/__init__.py +++ b/src/transformers/models/xglm/__init__.py @@ -42,6 +42,7 @@ _import_structure["modeling_flax_xglm"] = [ "FlaxXGLMModel", "FlaxXGLMPreTrainedModel", + "FlaxXGLMForCausalLM", ] @@ -56,7 +57,7 @@ from .modeling_xglm import XGLM_PRETRAINED_MODEL_ARCHIVE_LIST, XGLMForCausalLM, XGLMModel, XGLMPreTrainedModel if is_flax_available(): - from .modeling_xglm import FlaxXGLMModel, FlaxXGLMPreTrainedModel + from .modeling_xglm import FlaxXGLMForCausalLM, FlaxXGLMModel, FlaxXGLMPreTrainedModel else: diff --git a/src/transformers/models/xglm/modeling_flax_xglm.py b/src/transformers/models/xglm/modeling_flax_xglm.py index 6cdfeddff3c329..caef00bb8605b3 100644 --- a/src/transformers/models/xglm/modeling_flax_xglm.py +++ b/src/transformers/models/xglm/modeling_flax_xglm.py @@ -20,6 +20,8 @@ from functools import partial from typing import Callable, Optional, Tuple +import numpy as np + import flax.linen as nn import jax import jax.numpy as jnp @@ -31,13 +33,8 @@ from ...file_utils import add_start_docstrings, replace_return_docstrings from ...modeling_flax_outputs import ( - FlaxBaseModelOutput, FlaxBaseModelOutputWithPastAndCrossAttentions, FlaxCausalLMOutputWithCrossAttentions, - FlaxSeq2SeqLMOutput, - FlaxSeq2SeqModelOutput, - FlaxSeq2SeqQuestionAnsweringModelOutput, - FlaxSeq2SeqSequenceClassifierOutput, ) from ...modeling_flax_utils import ( ACT2FN, @@ -219,6 +216,20 @@ """ +def create_sinusoidal_positions(n_pos, dim, padding_idx=1): + half_dim = dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = np.exp(np.arange(half_dim) * -emb) + emb = np.expand_dims(np.arange(n_pos), 1) * np.expand_dims(emb, 0) + emb = np.concatenate([np.sin(emb), np.cos(emb)], 1) + emb = np.reshape(emb, (n_pos, dim)) + + if padding_idx is not None: + emb[padding_idx, :] = 0 + + return jnp.array(emb) + + def shift_tokens_right(input_ids: jnp.ndarray, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray: """ Shift input ids one token to the right. @@ -231,6 +242,7 @@ def shift_tokens_right(input_ids: jnp.ndarray, pad_token_id: int, decoder_start_ return shifted_input_ids +# Copied from transformers.models.bart.modeling_flax_bart.FlaxBartAttention with Bart->XGLM class FlaxXGLMAttention(nn.Module): config: XGLMConfig embed_dim: int @@ -368,7 +380,7 @@ def __call__( attention_bias = lax.select( attention_mask > 0, jnp.full(attention_mask.shape, 0.0).astype(self.dtype), - jnp.full(attention_mask.shape, float("-inf")).astype(self.dtype), + jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype), ) else: attention_bias = None @@ -396,116 +408,6 @@ def __call__( return attn_output, attn_weights -class FlaxXGLMEncoderLayer(nn.Module): - config: XGLMConfig - dtype: jnp.dtype = jnp.float32 - - def setup(self) -> None: - self.embed_dim = self.config.d_model - self.self_attn = FlaxXGLMAttention( - config=self.config, - embed_dim=self.embed_dim, - num_heads=self.config.encoder_attention_heads, - dropout=self.config.attention_dropout, - dtype=self.dtype, - ) - self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype) - self.dropout_layer = nn.Dropout(rate=self.config.dropout) - self.activation_fn = ACT2FN[self.config.activation_function] - self.activation_dropout_layer = nn.Dropout(rate=self.config.activation_dropout) - self.fc1 = nn.Dense( - self.config.encoder_ffn_dim, - dtype=self.dtype, - kernel_init=jax.nn.initializers.normal(self.config.init_std), - ) - self.fc2 = nn.Dense( - self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std) - ) - self.final_layer_norm = nn.LayerNorm(dtype=self.dtype) - - def __call__( - self, - hidden_states: jnp.ndarray, - attention_mask: jnp.ndarray, - output_attentions: bool = True, - deterministic: bool = True, - ) -> Tuple[jnp.ndarray]: - residual = hidden_states - hidden_states, attn_weights = self.self_attn(hidden_states=hidden_states, attention_mask=attention_mask) - - hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic) - hidden_states = residual + hidden_states - hidden_states = self.self_attn_layer_norm(hidden_states) - - residual = hidden_states - hidden_states = self.activation_fn(self.fc1(hidden_states)) - hidden_states = self.activation_dropout_layer(hidden_states, deterministic=deterministic) - hidden_states = self.fc2(hidden_states) - hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic) - hidden_states = residual + hidden_states - hidden_states = self.final_layer_norm(hidden_states) - - outputs = (hidden_states,) - - if output_attentions: - outputs += (attn_weights,) - - return outputs - - -class FlaxXGLMEncoderLayerCollection(nn.Module): - config: XGLMConfig - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - - def setup(self): - self.layers = [ - FlaxXGLMEncoderLayer(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.encoder_layers) - ] - self.layerdrop = self.config.encoder_layerdrop - - def __call__( - self, - hidden_states, - attention_mask, - deterministic: bool = True, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - ): - all_attentions = () if output_attentions else None - all_hidden_states = () if output_hidden_states else None - - for encoder_layer in self.layers: - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) - dropout_probability = random.uniform(0, 1) - if not deterministic and (dropout_probability < self.layerdrop): # skip the layer - layer_outputs = (None, None) - else: - layer_outputs = encoder_layer( - hidden_states, - attention_mask, - output_attentions, - deterministic, - ) - hidden_states = layer_outputs[0] - if output_attentions: - all_attentions = all_attentions + (layer_outputs[1],) - - if output_hidden_states: - all_hidden_states += (hidden_states,) - - outputs = (hidden_states, all_hidden_states, all_attentions) - - if not return_dict: - return tuple(v for v in outputs if v is not None) - - return FlaxBaseModelOutput( - last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions - ) - - class FlaxXGLMDecoderLayer(nn.Module): config: XGLMConfig dtype: jnp.dtype = jnp.float32 @@ -515,34 +417,37 @@ def setup(self) -> None: self.self_attn = FlaxXGLMAttention( config=self.config, embed_dim=self.embed_dim, - num_heads=self.config.decoder_attention_heads, + num_heads=self.config.attention_heads, dropout=self.config.attention_dropout, causal=True, dtype=self.dtype, ) + self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05) self.dropout_layer = nn.Dropout(rate=self.config.dropout) self.activation_fn = ACT2FN[self.config.activation_function] self.activation_dropout_layer = nn.Dropout(rate=self.config.activation_dropout) - self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype) - self.encoder_attn = FlaxXGLMAttention( - config=self.config, - embed_dim=self.embed_dim, - num_heads=self.config.decoder_attention_heads, - dropout=self.config.attention_dropout, - dtype=self.dtype, - ) - self.encoder_attn_layer_norm = nn.LayerNorm(dtype=self.dtype) + if self.config.add_cross_attention: + self.encoder_attn = FlaxXGLMAttention( + config=self.config, + embed_dim=self.embed_dim, + num_heads=self.config.decoder_attention_heads, + dropout=self.config.attention_dropout, + dtype=self.dtype, + ) + self.encoder_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05) + self.fc1 = nn.Dense( - self.config.encoder_ffn_dim, + self.config.ffn_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std), ) self.fc2 = nn.Dense( self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std) ) - self.final_layer_norm = nn.LayerNorm(dtype=self.dtype) + self.final_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05) + # Copied from transformers.models.mbart.modeling_flax_mbart.FlaxMBartDecoderLayer.__call__ def __call__( self, hidden_states: jnp.ndarray, @@ -554,6 +459,7 @@ def __call__( deterministic: bool = True, ) -> Tuple[jnp.ndarray]: residual = hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) # Self Attention hidden_states, self_attn_weights = self.self_attn( @@ -561,13 +467,13 @@ def __call__( ) hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic) hidden_states = residual + hidden_states - hidden_states = self.self_attn_layer_norm(hidden_states) # Cross-Attention Block cross_attn_weights = None if encoder_hidden_states is not None: residual = hidden_states + hidden_states = self.encoder_attn_layer_norm(hidden_states) hidden_states, cross_attn_weights = self.encoder_attn( hidden_states=hidden_states, key_value_states=encoder_hidden_states, @@ -575,16 +481,15 @@ def __call__( ) hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic) hidden_states = residual + hidden_states - hidden_states = self.encoder_attn_layer_norm(hidden_states) # Fully Connected residual = hidden_states + hidden_states = self.final_layer_norm(hidden_states) hidden_states = self.activation_fn(self.fc1(hidden_states)) hidden_states = self.activation_dropout_layer(hidden_states, deterministic=deterministic) hidden_states = self.fc2(hidden_states) hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic) hidden_states = residual + hidden_states - hidden_states = self.final_layer_norm(hidden_states) outputs = (hidden_states,) @@ -600,9 +505,9 @@ class FlaxXGLMDecoderLayerCollection(nn.Module): def setup(self): self.layers = [ - FlaxXGLMDecoderLayer(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.decoder_layers) + FlaxXGLMDecoderLayer(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.num_layers) ] - self.layerdrop = self.config.decoder_layerdrop + self.layerdrop = self.config.layerdrop def __call__( self, @@ -650,150 +555,55 @@ def __call__( if output_hidden_states: all_hidden_states += (hidden_states,) - outputs = [hidden_states, all_hidden_states, all_self_attns, all_cross_attentions] - - if not return_dict: - return tuple(v for v in outputs if v is not None) - - return FlaxBaseModelOutputWithPastAndCrossAttentions( - last_hidden_state=hidden_states, - hidden_states=all_hidden_states, - attentions=all_self_attns, - cross_attentions=all_cross_attentions, - ) - - -class FlaxXGLMClassificationHead(nn.Module): - """Head for sentence-level classification tasks.""" - - config: XGLMConfig - inner_dim: int - num_classes: int - pooler_dropout: float - dtype: jnp.dtype = jnp.float32 + outputs = (hidden_states, all_hidden_states, all_self_attns, all_cross_attentions) + return outputs - def setup(self): - self.dense = nn.Dense( - self.inner_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std) - ) - self.dropout = nn.Dropout(rate=self.pooler_dropout) - self.out_proj = nn.Dense( - self.num_classes, - dtype=self.dtype, - kernel_init=jax.nn.initializers.normal(self.config.init_std), - ) + # if not return_dict: + # return tuple(v for v in outputs if v is not None) - def __call__(self, hidden_states: jnp.ndarray, deterministic: bool): - hidden_states = self.dropout(hidden_states, deterministic=deterministic) - hidden_states = self.dense(hidden_states) - hidden_states = jnp.tanh(hidden_states) - hidden_states = self.dropout(hidden_states, deterministic=deterministic) - hidden_states = self.out_proj(hidden_states) - return hidden_states + # return FlaxBaseModelOutputWithPastAndCrossAttentions( + # last_hidden_state=hidden_states, + # hidden_states=all_hidden_states, + # attentions=all_self_attns, + # cross_attentions=all_cross_attentions, + # ) -class FlaxXGLMEncoder(nn.Module): +class FlaxXGLMModule(nn.Module): config: XGLMConfig dtype: jnp.dtype = jnp.float32 # the dtype of the computation - embed_tokens: Optional[nn.Embed] = None def setup(self): self.dropout_layer = nn.Dropout(rate=self.config.dropout) embed_dim = self.config.d_model self.padding_idx = self.config.pad_token_id - self.max_source_positions = self.config.max_position_embeddings - self.embed_scale = math.sqrt(embed_dim) if self.config.scale_embedding else 1.0 - - if self.embed_tokens is None: - self.embed_tokens = nn.Embed( - self.config.vocab_size, - embed_dim, - embedding_init=jax.nn.initializers.normal(self.config.init_std), - ) + self.max_target_positions = self.config.max_position_embeddings + self.embed_scale = math.sqrt(self.config.d_model) if self.config.scale_embedding else 1.0 - # XGLM is set up so that if padding_idx is specified then offset the embedding ids by 2 - # and adjust num_embeddings appropriately. Other models don't have this hack - self.offset = 2 - self.embed_positions = nn.Embed( - self.config.max_position_embeddings + self.offset, + self.embed_tokens = nn.Embed( + self.config.vocab_size, embed_dim, embedding_init=jax.nn.initializers.normal(self.config.init_std), ) - self.layers = FlaxXGLMEncoderLayerCollection(self.config, self.dtype) - self.layernorm_embedding = nn.LayerNorm(dtype=self.dtype) - - def __call__( - self, - input_ids, - attention_mask, - position_ids, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - deterministic: bool = True, - ): - input_shape = input_ids.shape - input_ids = input_ids.reshape(-1, input_shape[-1]) - - inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale - - embed_pos = self.embed_positions(position_ids + self.offset) - - hidden_states = inputs_embeds + embed_pos - hidden_states = self.layernorm_embedding(hidden_states) - hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic) - - outputs = self.layers( - hidden_states, - attention_mask, - deterministic=deterministic, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - if not return_dict: - return outputs - - return FlaxBaseModelOutput( - last_hidden_state=outputs.last_hidden_state, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - -class FlaxXGLMDecoder(nn.Module): - config: XGLMConfig - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - embed_tokens: Optional[nn.Embed] = None - - def setup(self): - self.dropout_layer = nn.Dropout(rate=self.config.dropout) - - embed_dim = self.config.d_model - self.padding_idx = self.config.pad_token_id - self.max_target_positions = self.config.max_position_embeddings - self.embed_scale = math.sqrt(self.config.d_model) if self.config.scale_embedding else 1.0 - - if self.embed_tokens is None: - self.embed_tokens = nn.Embed( - self.config.vocab_size, - embed_dim, - embedding_init=jax.nn.initializers.normal(self.config.init_std), - ) # XGLM is set up so that if padding_idx is specified then offset the embedding ids by 2 # and adjust num_embeddings appropriately. Other models don't have this hack self.offset = 2 - self.embed_positions = nn.Embed( - self.config.max_position_embeddings + self.offset, - embed_dim, - embedding_init=jax.nn.initializers.normal(self.config.init_std), + # TODO: padding idx should be zero + self.embed_positions = create_sinusoidal_positions( + self.config.max_position_embeddings + self.offset, embed_dim ) - self.layers = FlaxXGLMDecoderLayerCollection(self.config, self.dtype) - self.layernorm_embedding = nn.LayerNorm(dtype=self.dtype) + self.layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05) + + def _create_pos_ids(self, input_ids, position_ids): + mask_ne = jnp.not_equal(input_ids, 1).astype("i4") + mask_eq = jnp.equal(input_ids, 1).astype("i4") + padding_idx = self.config.pad_token_id + + position_ids = (position_ids * mask_ne - padding_idx) + (mask_eq * self.offset) + return position_ids def __call__( self, @@ -814,11 +624,11 @@ def __call__( inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale # embed positions - positions = self.embed_positions(position_ids + self.offset) + # position_ids = self._create_pos_ids(input_ids, position_ids) + position_ids = position_ids + self.offset + positions = jnp.take(self.embed_positions, position_ids, axis=0) hidden_states = inputs_embeds + positions - hidden_states = self.layernorm_embedding(hidden_states) - hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic) outputs = self.layers( @@ -833,83 +643,18 @@ def __call__( return_dict=return_dict, ) + last_hidden_states = outputs[0] + last_hidden_states = self.layer_norm(last_hidden_states) + if not return_dict: - return outputs + outputs = (last_hidden_states,) + outputs[1:] + return tuple(v for v in outputs if v is not None) return FlaxBaseModelOutputWithPastAndCrossAttentions( - last_hidden_state=outputs.last_hidden_state, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - cross_attentions=outputs.cross_attentions, - ) - - -class FlaxXGLMModule(nn.Module): - config: XGLMConfig - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - - def setup(self): - self.shared = nn.Embed( - self.config.vocab_size, - self.config.d_model, - embedding_init=jax.nn.initializers.normal(self.config.init_std), - ) - - self.encoder = FlaxXGLMEncoder(self.config, dtype=self.dtype, embed_tokens=self.shared) - self.decoder = FlaxXGLMDecoder(self.config, dtype=self.dtype, embed_tokens=self.shared) - - def _get_encoder_module(self): - return self.encoder - - def _get_decoder_module(self): - return self.decoder - - def __call__( - self, - input_ids, - attention_mask, - decoder_input_ids, - decoder_attention_mask, - position_ids, - decoder_position_ids, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - deterministic: bool = True, - ): - encoder_outputs = self.encoder( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - deterministic=deterministic, - ) - - decoder_outputs = self.decoder( - input_ids=decoder_input_ids, - attention_mask=decoder_attention_mask, - position_ids=decoder_position_ids, - encoder_hidden_states=encoder_outputs[0], - encoder_attention_mask=attention_mask, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - deterministic=deterministic, - ) - - if not return_dict: - return decoder_outputs + encoder_outputs - - return FlaxSeq2SeqModelOutput( - last_hidden_state=decoder_outputs.last_hidden_state, - decoder_hidden_states=decoder_outputs.hidden_states, - decoder_attentions=decoder_outputs.attentions, - cross_attentions=decoder_outputs.cross_attentions, - encoder_last_hidden_state=encoder_outputs.last_hidden_state, - encoder_hidden_states=encoder_outputs.hidden_states, - encoder_attentions=encoder_outputs.attentions, + last_hidden_state=last_hidden_states, + hidden_states=outputs[1], + attentions=outputs[2], + cross_attentions=outputs[3], ) @@ -932,30 +677,29 @@ def __init__( def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict: # init input tensors input_ids = jnp.zeros(input_shape, dtype="i4") - # make sure initialization pass will work for FlaxXGLMForSequenceClassificationModule - input_ids = jax.ops.index_update(input_ids, (..., -1), self.config.eos_token_id) attention_mask = jnp.ones_like(input_ids) - decoder_input_ids = input_ids - decoder_attention_mask = jnp.ones_like(input_ids) - - batch_size, sequence_length = input_ids.shape - position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)) - decoder_position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)) - + position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape) params_rng, dropout_rng = jax.random.split(rng) rngs = {"params": params_rng, "dropout": dropout_rng} - return self.module.init( - rngs, - input_ids, - attention_mask, - decoder_input_ids, - decoder_attention_mask, - position_ids, - decoder_position_ids, - )["params"] + if self.config.add_cross_attention: + encoder_hidden_states = jnp.zeros(input_shape + (self.config.n_embd,)) + encoder_attention_mask = attention_mask + module_init_outputs = self.module.init( + rngs, + input_ids, + attention_mask, + position_ids, + encoder_hidden_states, + encoder_attention_mask, + return_dict=False, + ) + else: + module_init_outputs = self.module.init(rngs, input_ids, attention_mask, position_ids, return_dict=False) - def init_cache(self, batch_size, max_length, encoder_outputs): + return module_init_outputs["params"] + + def init_cache(self, batch_size, max_length): r""" Args: batch_size (`int`): @@ -969,66 +713,41 @@ def init_cache(self, batch_size, max_length, encoder_outputs): encoder. Used in the cross-attention of the decoder. """ # init input variables to retrieve cache - decoder_input_ids = jnp.ones((batch_size, max_length), dtype="i4") - decoder_attention_mask = jnp.ones_like(decoder_input_ids) - decoder_position_ids = jnp.broadcast_to( - jnp.arange(jnp.atleast_2d(decoder_input_ids).shape[-1]), decoder_input_ids.shape - ) - - def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, **kwargs): - decoder_module = module._get_decoder_module() - return decoder_module( - decoder_input_ids, - decoder_attention_mask, - decoder_position_ids, - **kwargs, - ) + input_ids = jnp.ones((batch_size, max_length), dtype="i4") + attention_mask = jnp.ones_like(input_ids, dtype="i4") + position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape) init_variables = self.module.init( - jax.random.PRNGKey(0), - decoder_input_ids=decoder_input_ids, - decoder_attention_mask=decoder_attention_mask, - decoder_position_ids=decoder_position_ids, - encoder_hidden_states=encoder_outputs[0], - init_cache=True, - method=_decoder_forward, # we only need to call the decoder to init the cache + jax.random.PRNGKey(0), input_ids, attention_mask, position_ids, return_dict=False, init_cache=True ) return unfreeze(init_variables["cache"]) - @add_start_docstrings(XGLM_ENCODE_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=FlaxBaseModelOutput, config_class=XGLMConfig) - def encode( + def __call__( self, input_ids: jnp.ndarray, attention_mask: Optional[jnp.ndarray] = None, position_ids: Optional[jnp.ndarray] = None, + encoder_hidden_states: Optional[jnp.ndarray] = None, + encoder_attention_mask: Optional[jnp.ndarray] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, train: bool = False, params: dict = None, + past_key_values: dict = None, dropout_rng: PRNGKey = None, ): - r""" - Returns: - - Example:: - - >>> from transformers import XGLMTokenizer, FlaxXGLMForConditionalGeneration - - >>> model = FlaxXGLMForConditionalGeneration.from_pretrained('facebook/xglm-564M') - >>> tokenizer = XGLMTokenizer.from_pretrained('facebook/xglm-564M') - - >>> text = "My friends are cool but they eat too many carbs." - >>> inputs = tokenizer(text, max_length=1024, return_tensors='np') - >>> encoder_outputs = model.encode(**inputs) - """ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.config.return_dict + if encoder_hidden_states is not None and encoder_attention_mask is None: + batch_size, sequence_length = encoder_hidden_states.shape[:2] + encoder_attention_mask = jnp.ones((batch_size, sequence_length)) + + # prepare encoder inputs if attention_mask is None: attention_mask = jnp.ones_like(input_ids) if position_ids is None: @@ -1036,690 +755,132 @@ def encode( position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)) # Handle any PRNG if needed - rngs = {} - if dropout_rng is not None: - rngs["dropout"] = dropout_rng + rngs = {"dropout": dropout_rng} if dropout_rng is not None else {} - def _encoder_forward(module, input_ids, attention_mask, position_ids, **kwargs): - encode_module = module._get_encoder_module() - return encode_module(input_ids, attention_mask, position_ids, **kwargs) + inputs = {"params": params or self.params} - return self.module.apply( - {"params": params or self.params}, - input_ids=jnp.array(input_ids, dtype="i4"), - attention_mask=jnp.array(attention_mask, dtype="i4"), - position_ids=jnp.array(position_ids, dtype="i4"), - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - deterministic=not train, - rngs=rngs, - method=_encoder_forward, - ) - - @add_start_docstrings(XGLM_DECODE_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=FlaxBaseModelOutputWithPastAndCrossAttentions, config_class=XGLMConfig) - def decode( - self, - decoder_input_ids, - encoder_outputs, - encoder_attention_mask: Optional[jnp.ndarray] = None, - decoder_attention_mask: Optional[jnp.ndarray] = None, - decoder_position_ids: Optional[jnp.ndarray] = None, - past_key_values: dict = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - train: bool = False, - params: dict = None, - dropout_rng: PRNGKey = None, - ): - r""" - Returns: - - Example:: - - >>> from transformers import XGLMTokenizer, FlaxXGLMForConditionalGeneration - - >>> model = FlaxXGLMForConditionalGeneration.from_pretrained('facebook/xglm-564M') - >>> tokenizer = XGLMTokenizer.from_pretrained('facebook/xglm-564M') - - >>> text = "My friends are cool but they eat too many carbs." - >>> inputs = tokenizer(text, max_length=1024, return_tensors='np') - >>> encoder_outputs = model.encode(**inputs) - - >>> decoder_start_token_id = model.config.decoder_start_token_id - >>> decoder_input_ids = jnp.ones((inputs.input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id - - >>> outputs = model.decode(decoder_input_ids, encoder_outputs) - >>> last_decoder_hidden_states = outputs.last_hidden_state - """ - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.return_dict - - encoder_hidden_states = encoder_outputs[0] - if encoder_attention_mask is None: - batch_size, sequence_length = encoder_hidden_states.shape[:2] - encoder_attention_mask = jnp.ones((batch_size, sequence_length)) - - batch_size, sequence_length = decoder_input_ids.shape - if decoder_attention_mask is None: - decoder_attention_mask = jnp.ones((batch_size, sequence_length)) - - if decoder_position_ids is None: - if past_key_values is not None: - raise ValueError("Make sure to provide `decoder_position_ids` when passing `past_key_values`.") - - decoder_position_ids = jnp.broadcast_to( - jnp.arange(sequence_length)[None, :], (batch_size, sequence_length) - ) - - # Handle any PRNG if needed - rngs = {} - if dropout_rng is not None: - rngs["dropout"] = dropout_rng - - inputs = {"params": params or self.params} - - # if past_key_values are passed then cache is already initialized a private flag init_cache has to be - # passed down to ensure cache is used. It has to be made sure that cache is marked as mutable so that - # it can be changed by FlaxXGLMAttention module + # if past_key_values are passed then cache is already initialized a private flag init_cache has to be passed down to ensure cache is used. It has to be made sure that cache is marked as mutable so that it can be changed by FlaxGPT2Attention module if past_key_values: inputs["cache"] = past_key_values mutable = ["cache"] else: mutable = False - def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, **kwargs): - decoder_module = module._get_decoder_module() - return decoder_module( - decoder_input_ids, - decoder_attention_mask, - decoder_position_ids, - **kwargs, - ) - outputs = self.module.apply( inputs, - decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"), - decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"), - decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"), + input_ids=jnp.array(input_ids, dtype="i4"), + attention_mask=jnp.array(attention_mask, dtype="i4"), + position_ids=jnp.array(position_ids, dtype="i4"), encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=jnp.array(encoder_attention_mask, dtype="i4"), + encoder_attention_mask=encoder_attention_mask, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, deterministic=not train, rngs=rngs, mutable=mutable, - method=_decoder_forward, ) # add updated cache to model output if past_key_values is not None and return_dict: - outputs, past = outputs - outputs["past_key_values"] = unfreeze(past["cache"]) + outputs, past_key_values = outputs + outputs["past_key_values"] = unfreeze(past_key_values["cache"]) return outputs elif past_key_values is not None and not return_dict: - outputs, past = outputs - outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:] + outputs, past_key_values = outputs + outputs = outputs[:1] + (unfreeze(past_key_values["cache"]),) + outputs[1:] return outputs - def __call__( - self, - input_ids: jnp.ndarray, - attention_mask: Optional[jnp.ndarray] = None, - decoder_input_ids: Optional[jnp.ndarray] = None, - decoder_attention_mask: Optional[jnp.ndarray] = None, - position_ids: Optional[jnp.ndarray] = None, - decoder_position_ids: Optional[jnp.ndarray] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - train: bool = False, - params: dict = None, - dropout_rng: PRNGKey = None, - ): - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.return_dict - - # prepare encoder inputs - if attention_mask is None: - attention_mask = jnp.ones_like(input_ids) - if position_ids is None: - batch_size, sequence_length = input_ids.shape - position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)) - # prepare decoder inputs - if decoder_input_ids is None: - decoder_input_ids = shift_tokens_right( - input_ids, self.config.pad_token_id, decoder_start_token_id=self.config.decoder_start_token_id - ) - if decoder_attention_mask is None: - decoder_attention_mask = jnp.ones_like(decoder_input_ids) - if decoder_position_ids is None: - batch_size, sequence_length = decoder_input_ids.shape - decoder_position_ids = jnp.broadcast_to( - jnp.arange(sequence_length)[None, :], (batch_size, sequence_length) - ) - - # Handle any PRNG if needed - rngs = {"dropout": dropout_rng} if dropout_rng is not None else {} - - return self.module.apply( - {"params": params or self.params}, - input_ids=jnp.array(input_ids, dtype="i4"), - attention_mask=jnp.array(attention_mask, dtype="i4"), - position_ids=jnp.array(position_ids, dtype="i4"), - decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"), - decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"), - decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"), - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - deterministic=not train, - rngs=rngs, - ) - - -@add_start_docstrings( - "The bare XGLM Model transformer outputting raw hidden-states without any specific head on top.", - XGLM_START_DOCSTRING, -) class FlaxXGLMModel(FlaxXGLMPreTrainedModel): - config: XGLMConfig - dtype: jnp.dtype = jnp.float32 # the dtype of the computation module_class = FlaxXGLMModule -append_call_sample_docstring( - FlaxXGLMModel, _TOKENIZER_FOR_DOC, _CHECKPOINT_FOR_DOC, FlaxSeq2SeqModelOutput, _CONFIG_FOR_DOC -) - - -class FlaxXGLMForConditionalGenerationModule(nn.Module): +class FlaxXGLMForCausalLMModule(nn.Module): config: XGLMConfig - dtype: jnp.dtype = jnp.float32 - bias_init: Callable[..., jnp.ndarray] = jax.nn.initializers.zeros + dtype: jnp.dtype = jnp.float32 # the dtype of the computation def setup(self): - self.model = FlaxXGLMModule(config=self.config, dtype=self.dtype) + self.model = FlaxXGLMModule(self.config, self.dtype) self.lm_head = nn.Dense( - self.model.shared.num_embeddings, + self.config.vocab_size, use_bias=False, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std), ) - self.final_logits_bias = self.param("final_logits_bias", self.bias_init, (1, self.model.shared.num_embeddings)) - - def _get_encoder_module(self): - return self.model.encoder - - def _get_decoder_module(self): - return self.model.decoder def __call__( self, input_ids, attention_mask, - decoder_input_ids, - decoder_attention_mask, position_ids, - decoder_position_ids, + encoder_hidden_states: Optional[jnp.ndarray] = None, + encoder_attention_mask: Optional[jnp.ndarray] = None, + init_cache: bool = False, output_attentions: bool = False, output_hidden_states: bool = False, return_dict: bool = True, deterministic: bool = True, ): + outputs = self.model( - input_ids=input_ids, - attention_mask=attention_mask, - decoder_input_ids=decoder_input_ids, - decoder_attention_mask=decoder_attention_mask, - position_ids=position_ids, - decoder_position_ids=decoder_position_ids, + input_ids, + attention_mask, + position_ids, + encoder_hidden_states, + encoder_attention_mask, + deterministic=deterministic, + init_cache=init_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, - deterministic=deterministic, ) hidden_states = outputs[0] if self.config.tie_word_embeddings: - shared_embedding = self.model.variables["params"]["shared"]["embedding"] + shared_embedding = self.model.variables["params"]["embed_tokens"]["embedding"] lm_logits = self.lm_head.apply({"params": {"kernel": shared_embedding.T}}, hidden_states) else: lm_logits = self.lm_head(hidden_states) - lm_logits += self.final_logits_bias.astype(self.dtype) - if not return_dict: - output = (lm_logits,) + outputs[1:] - return output + return (lm_logits,) + outputs[1:] - return FlaxSeq2SeqLMOutput( + return FlaxCausalLMOutputWithCrossAttentions( logits=lm_logits, - decoder_hidden_states=outputs.decoder_hidden_states, - decoder_attentions=outputs.decoder_attentions, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, cross_attentions=outputs.cross_attentions, - encoder_last_hidden_state=outputs.encoder_last_hidden_state, - encoder_hidden_states=outputs.encoder_hidden_states, - encoder_attentions=outputs.encoder_attentions, - ) - - -@add_start_docstrings( - "The XGLM Model with a language modeling head. Can be used for summarization.", XGLM_START_DOCSTRING -) -class FlaxXGLMForConditionalGeneration(FlaxXGLMPreTrainedModel): - module_class = FlaxXGLMForConditionalGenerationModule - dtype: jnp.dtype = jnp.float32 - - @add_start_docstrings(XGLM_DECODE_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=FlaxCausalLMOutputWithCrossAttentions, config_class=XGLMConfig) - def decode( - self, - decoder_input_ids, - encoder_outputs, - encoder_attention_mask: Optional[jnp.ndarray] = None, - decoder_attention_mask: Optional[jnp.ndarray] = None, - decoder_position_ids: Optional[jnp.ndarray] = None, - past_key_values: dict = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - deterministic: bool = True, - params: dict = None, - dropout_rng: PRNGKey = None, - ): - r""" - Returns: - - Example:: - - >>> from transformers import XGLMTokenizer, FlaxXGLMForConditionalGeneration - - >>> model = FlaxXGLMForConditionalGeneration.from_pretrained('facebook/xglm-564M') - >>> tokenizer = XGLMTokenizer.from_pretrained('facebook/xglm-564M') - - >>> text = "My friends are cool but they eat too many carbs." - >>> inputs = tokenizer(text, max_length=1024, return_tensors='np') - >>> encoder_outputs = model.encode(**inputs) - - >>> decoder_start_token_id = model.config.decoder_start_token_id - >>> decoder_input_ids = jnp.ones((inputs.input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id - - >>> outputs = model.decode(decoder_input_ids, encoder_outputs) - >>> logits = outputs.logits - """ - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) - return_dict = return_dict if return_dict is not None else self.config.return_dict - encoder_hidden_states = encoder_outputs[0] - if encoder_attention_mask is None: - batch_size, sequence_length = encoder_hidden_states.shape[:2] - encoder_attention_mask = jnp.ones((batch_size, sequence_length)) - batch_size, sequence_length = decoder_input_ids.shape - if decoder_attention_mask is None: - decoder_attention_mask = jnp.ones((batch_size, sequence_length)) +class FlaxXGLMForCausalLM(FlaxXGLMPreTrainedModel): + module_class = FlaxXGLMForCausalLMModule - if decoder_position_ids is None: - if past_key_values is not None: - raise ValueError("Make sure to provide `decoder_position_ids` when passing `past_key_values`.") - - decoder_position_ids = jnp.broadcast_to( - jnp.arange(sequence_length)[None, :], (batch_size, sequence_length) - ) - - # Handle any PRNG if needed - rngs = {} - if dropout_rng is not None: - rngs["dropout"] = dropout_rng - - inputs = {"params": params or self.params} - - # if past_key_values are passed then cache is already initialized a private flag init_cache has to be - # passed down to ensure cache is used. It has to be made sure that cache is marked as mutable so that - # it can be changed by FlaxXGLMAttention module - if past_key_values: - inputs["cache"] = past_key_values - mutable = ["cache"] - else: - mutable = False - - def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, **kwargs): - decoder_module = module._get_decoder_module() - outputs = decoder_module( - decoder_input_ids, - decoder_attention_mask, - decoder_position_ids, - **kwargs, - ) - hidden_states = outputs[0] - - if self.config.tie_word_embeddings: - shared_embedding = module.model.variables["params"]["shared"]["embedding"] - lm_logits = module.lm_head.apply({"params": {"kernel": shared_embedding.T}}, hidden_states) - else: - lm_logits = module.lm_head(hidden_states) - - lm_logits += module.final_logits_bias.astype(self.dtype) - return lm_logits, outputs - - outputs = self.module.apply( - inputs, - decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"), - decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"), - decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"), - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=jnp.array(encoder_attention_mask, dtype="i4"), - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - deterministic=deterministic, - rngs=rngs, - mutable=mutable, - method=_decoder_forward, - ) - - if past_key_values is None: - lm_logits, decoder_outputs = outputs - else: - (lm_logits, decoder_outputs), past = outputs - - if return_dict: - outputs = FlaxCausalLMOutputWithCrossAttentions( - logits=lm_logits, - hidden_states=decoder_outputs.hidden_states, - attentions=decoder_outputs.attentions, - cross_attentions=decoder_outputs.cross_attentions, - ) - else: - outputs = (lm_logits,) + decoder_outputs[1:] - - # add updated cache to model output - if past_key_values is not None and return_dict: - outputs["past_key_values"] = unfreeze(past["cache"]) - return outputs - elif past_key_values is not None and not return_dict: - outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:] - - return outputs - - def prepare_inputs_for_generation( - self, - decoder_input_ids, - max_length, - attention_mask: Optional[jnp.DeviceArray] = None, - decoder_attention_mask: Optional[jnp.DeviceArray] = None, - encoder_outputs=None, - **kwargs - ): + def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jnp.DeviceArray] = None): # initializing the cache - batch_size, seq_length = decoder_input_ids.shape + batch_size, seq_length = input_ids.shape - past_key_values = self.init_cache(batch_size, max_length, encoder_outputs) + past_key_values = self.init_cache(batch_size, max_length) # Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length. - # But since the decoder uses a causal mask, those positions are masked anyways. + # But since GPT2 uses a causal mask, those positions are masked anyways. # Thus we can create a single static attention_mask here, which is more efficient for compilation extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4") - if decoder_attention_mask is not None: - position_ids = decoder_attention_mask.cumsum(axis=-1) - 1 - extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, decoder_attention_mask, (0, 0)) + if attention_mask is not None: + position_ids = attention_mask.cumsum(axis=-1) - 1 + extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, attention_mask, (0, 0)) else: position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length)) return { "past_key_values": past_key_values, - "encoder_outputs": encoder_outputs, - "encoder_attention_mask": attention_mask, - "decoder_attention_mask": extended_attention_mask, - "decoder_position_ids": position_ids, + "attention_mask": extended_attention_mask, + "position_ids": position_ids, } def update_inputs_for_generation(self, model_outputs, model_kwargs): model_kwargs["past_key_values"] = model_outputs.past_key_values - model_kwargs["decoder_position_ids"] = model_kwargs["decoder_position_ids"][:, -1:] + 1 + model_kwargs["position_ids"] = model_kwargs["position_ids"][:, -1:] + 1 return model_kwargs - - -FLAX_XGLM_CONDITIONAL_GENERATION_DOCSTRING = """ - Returns: - - Summarization example:: - - >>> from transformers import XGLMTokenizer, FlaxXGLMForConditionalGeneration - - >>> model = FlaxXGLMForConditionalGeneration.from_pretrained('facebook/xglm-564M') - >>> tokenizer = XGLMTokenizer.from_pretrained('facebook/xglm-564M') - - >>> ARTICLE_TO_SUMMARIZE = "My friends are cool but they eat too many carbs." - >>> inputs = tokenizer([ARTICLE_TO_SUMMARIZE], max_length=1024, return_tensors='np') - - >>> # Generate Summary - >>> summary_ids = model.generate(inputs['input_ids']).sequences - >>> print(tokenizer.batch_decode(summary_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)) - - Mask filling example:: - - >>> from transformers import XGLMTokenizer, FlaxXGLMForConditionalGeneration - >>> tokenizer = XGLMTokenizer.from_pretrained('facebook/xglm-564M') - >>> TXT = "My friends are but they eat too many carbs." - - >>> model = FlaxXGLMForConditionalGeneration.from_pretrained('facebook/xglm-564M') - >>> input_ids = tokenizer([TXT], return_tensors='np')['input_ids'] - >>> logits = model(input_ids).logits - - >>> masked_index = (input_ids[0] == tokenizer.mask_token_id).nonzero().item() - >>> probs = jax.nn.softmax(logits[0, masked_index], axis=0) - >>> values, predictions = jax.lax.top_k(probs) - - >>> tokenizer.decode(predictions).split() -""" - -overwrite_call_docstring( - FlaxXGLMForConditionalGeneration, XGLM_INPUTS_DOCSTRING + FLAX_XGLM_CONDITIONAL_GENERATION_DOCSTRING -) -append_replace_return_docstrings( - FlaxXGLMForConditionalGeneration, output_type=FlaxSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC -) - - -class FlaxXGLMForSequenceClassificationModule(nn.Module): - config: XGLMConfig - dtype: jnp.dtype = jnp.float32 - num_labels: Optional[int] = None - - def setup(self): - self.model = FlaxXGLMModule(config=self.config, dtype=self.dtype) - self.classification_head = FlaxXGLMClassificationHead( - config=self.config, - inner_dim=self.config.d_model, - num_classes=self.num_labels if self.num_labels is not None else self.config.num_labels, - pooler_dropout=self.config.classifier_dropout, - ) - - def _get_encoder_module(self): - return self.model.encoder - - def _get_decoder_module(self): - return self.model.decoder - - def __call__( - self, - input_ids, - attention_mask, - decoder_input_ids, - decoder_attention_mask, - position_ids, - decoder_position_ids, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - deterministic: bool = True, - ): - outputs = self.model( - input_ids=input_ids, - attention_mask=attention_mask, - decoder_input_ids=decoder_input_ids, - decoder_attention_mask=decoder_attention_mask, - position_ids=position_ids, - decoder_position_ids=decoder_position_ids, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - deterministic=deterministic, - ) - - hidden_states = outputs[0] # last hidden state - - eos_mask = jnp.where(input_ids == self.config.eos_token_id, 1, 0) - - # The first condition is necessary to overcome jax._src.errors.ConcretizationTypeError during JIT compilation - if type(eos_mask) != jax.interpreters.partial_eval.DynamicJaxprTracer: - if len(jnp.unique(eos_mask.sum(1))) > 1: - raise ValueError("All examples must have the same number of tokens.") - - if any(eos_mask.sum(1) == 0): - raise ValueError("There are missing tokens in input_ids") - - # Ensure to keep 1 only for the last token for each example - eos_mask_noised = eos_mask + jnp.arange(eos_mask.shape[1]) * 1e-6 - eos_mask = jnp.where(eos_mask_noised == eos_mask_noised.max(1).reshape(-1, 1), 1, 0) - - sentence_representation = jnp.einsum("ijk, ij -> ijk", hidden_states, eos_mask).sum(1) - logits = self.classification_head(sentence_representation, deterministic=deterministic) - - if not return_dict: - output = (logits,) + outputs[1:] - return output - - return FlaxSeq2SeqSequenceClassifierOutput( - logits=logits, - decoder_hidden_states=outputs.decoder_hidden_states, - decoder_attentions=outputs.decoder_attentions, - cross_attentions=outputs.cross_attentions, - encoder_last_hidden_state=outputs.encoder_last_hidden_state, - encoder_hidden_states=outputs.encoder_hidden_states, - encoder_attentions=outputs.encoder_attentions, - ) - - -@add_start_docstrings( - """ - XGLM model with a sequence classification/head on top (a linear layer on top of the pooled output) e.g. for GLUE - tasks. - """, - XGLM_START_DOCSTRING, -) -class FlaxXGLMForSequenceClassification(FlaxXGLMPreTrainedModel): - module_class = FlaxXGLMForSequenceClassificationModule - dtype = jnp.float32 - - -append_call_sample_docstring( - FlaxXGLMForSequenceClassification, - _TOKENIZER_FOR_DOC, - _CHECKPOINT_FOR_DOC, - FlaxSeq2SeqSequenceClassifierOutput, - _CONFIG_FOR_DOC, -) - - -class FlaxXGLMForQuestionAnsweringModule(nn.Module): - config: XGLMConfig - dtype: jnp.dtype = jnp.float32 - num_labels = 2 - - def setup(self): - self.model = FlaxXGLMModule(config=self.config, dtype=self.dtype) - self.qa_outputs = nn.Dense( - self.num_labels, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std) - ) - - def _get_encoder_module(self): - return self.model.encoder - - def _get_decoder_module(self): - return self.model.decoder - - def __call__( - self, - input_ids, - attention_mask, - decoder_input_ids, - decoder_attention_mask, - position_ids, - decoder_position_ids, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - deterministic: bool = True, - ): - outputs = self.model( - input_ids=input_ids, - attention_mask=attention_mask, - decoder_input_ids=decoder_input_ids, - decoder_attention_mask=decoder_attention_mask, - position_ids=position_ids, - decoder_position_ids=decoder_position_ids, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - deterministic=deterministic, - ) - - sequence_output = outputs[0] - - logits = self.qa_outputs(sequence_output) - start_logits, end_logits = jnp.split(logits, logits.shape[-1], axis=-1) - start_logits = start_logits.squeeze(-1) - end_logits = end_logits.squeeze(-1) - - if not return_dict: - output = (start_logits, end_logits) + outputs[1:] - return output - - return FlaxSeq2SeqQuestionAnsweringModelOutput( - start_logits=start_logits, - end_logits=end_logits, - decoder_hidden_states=outputs.decoder_hidden_states, - decoder_attentions=outputs.decoder_attentions, - cross_attentions=outputs.cross_attentions, - encoder_last_hidden_state=outputs.encoder_last_hidden_state, - encoder_hidden_states=outputs.encoder_hidden_states, - encoder_attentions=outputs.encoder_attentions, - ) - - -@add_start_docstrings( - """ - XGLM Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear - layer on top of the hidden-states output to compute `span start logits` and `span end logits`). - """, - XGLM_START_DOCSTRING, -) -class FlaxXGLMForQuestionAnswering(FlaxXGLMPreTrainedModel): - module_class = FlaxXGLMForQuestionAnsweringModule - dtype = jnp.float32 - - -append_call_sample_docstring( - FlaxXGLMForQuestionAnswering, - _TOKENIZER_FOR_DOC, - _CHECKPOINT_FOR_DOC, - FlaxSeq2SeqQuestionAnsweringModelOutput, - _CONFIG_FOR_DOC, -) diff --git a/tests/test_modeling_flax_xglm.py b/tests/test_modeling_flax_xglm.py index 8ccb6342144fc1..ec3e7b5ed81f02 100644 --- a/tests/test_modeling_flax_xglm.py +++ b/tests/test_modeling_flax_xglm.py @@ -14,337 +14,342 @@ # limitations under the License. +import tempfile import unittest -from transformers import XGLMConfig, XGLMTokenizer, is_flax_available -from transformers.testing_utils import require_flax, require_sentencepiece, require_tokenizers, slow +import transformers +from transformers import XGLMConfig, XGLMTokenizer, is_flax_available, is_torch_available +from transformers.testing_utils import ( + is_pt_flax_cross_test, + require_flax, + require_sentencepiece, + require_tokenizers, + slow, +) from .test_configuration_common import ConfigTester -from .test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor +from .test_generation_flax_utils import FlaxGenerationTesterMixin +from .test_modeling_flax_common import FlaxModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask if is_flax_available(): import numpy as np + import jax import jax.numpy as jnp - from transformers import ( - FlaxXGLMForConditionalGeneration, - FlaxXGLMForQuestionAnswering, - FlaxXGLMForSequenceClassification, - FlaxXGLMModel, + from transformers import FlaxXGLMForCausalLM, FlaxXGLMModel + from transformers.modeling_flax_pytorch_utils import ( + convert_pytorch_state_dict_to_flax, + load_flax_weights_in_pytorch_model, ) + from transformers.models.xglm.modeling_flax_xglm import FlaxXGLMForCausalLM, FlaxXGLMModel + + +if is_torch_available(): + import torch @require_flax class FlaxXGLMModelTester: - config_cls = XGLMConfig - config_updates = {} - hidden_act = "gelu" - def __init__( self, parent, - batch_size=13, + batch_size=14, seq_length=7, is_training=True, - use_labels=False, + use_input_mask=True, + use_labels=True, vocab_size=99, - hidden_size=32, + d_model=32, num_hidden_layers=5, num_attention_heads=4, - intermediate_size=37, - hidden_dropout_prob=0.1, - attention_probs_dropout_prob=0.1, - max_position_embeddings=20, - eos_token_id=2, - pad_token_id=1, - bos_token_id=0, + ffn_dim=37, + activation_function="gelu", + activation_dropout=0.1, + attention_dropout=0.1, + max_position_embeddings=512, + initializer_range=0.02, + scope=None, ): self.parent = parent self.batch_size = batch_size self.seq_length = seq_length self.is_training = is_training + self.use_input_mask = use_input_mask self.use_labels = use_labels self.vocab_size = vocab_size - self.hidden_size = hidden_size + self.hidden_size = d_model self.num_hidden_layers = num_hidden_layers self.num_attention_heads = num_attention_heads - self.intermediate_size = intermediate_size - - self.hidden_dropout_prob = hidden_dropout_prob - self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.ffn_dim = ffn_dim + self.activation_function = activation_function + self.activation_dropout = activation_dropout + self.attention_dropout = attention_dropout self.max_position_embeddings = max_position_embeddings - self.eos_token_id = eos_token_id - self.pad_token_id = pad_token_id - self.bos_token_id = bos_token_id + self.initializer_range = initializer_range + self.scope = None + self.bos_token_id = 0 + self.eos_token_id = 2 + self.pad_token_id = 1 - def prepare_config_and_inputs_for_common(self): - input_ids = ids_tensor([self.batch_size, self.seq_length - 1], self.vocab_size).clip(3, self.vocab_size) - eos_tensor = np.expand_dims(np.array([self.eos_token_id] * self.batch_size), 1) - input_ids = np.concatenate([input_ids, eos_tensor], axis=1) + def prepare_config_and_inputs(self): + input_ids = np.clip(ids_tensor([self.batch_size, self.seq_length], self.vocab_size), 3, self.vocab_size) - decoder_input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) + input_mask = None + if self.use_input_mask: + input_mask = random_attention_mask([self.batch_size, self.seq_length]) - config = self.config_cls( + config = XGLMConfig( vocab_size=self.vocab_size, d_model=self.hidden_size, - encoder_layers=self.num_hidden_layers, - decoder_layers=self.num_hidden_layers, - encoder_attention_heads=self.num_attention_heads, - decoder_attention_heads=self.num_attention_heads, - encoder_ffn_dim=self.intermediate_size, - decoder_ffn_dim=self.intermediate_size, - dropout=self.hidden_dropout_prob, - attention_dropout=self.attention_probs_dropout_prob, + num_layers=self.num_hidden_layers, + attention_heads=self.num_attention_heads, + ffn_dim=self.ffn_dim, + activation_function=self.activation_function, + activation_dropout=self.activation_dropout, + attention_dropout=self.attention_dropout, max_position_embeddings=self.max_position_embeddings, - eos_token_ids=[2], + initializer_range=self.initializer_range, + use_cache=True, bos_token_id=self.bos_token_id, + eos_token_id=self.eos_token_id, pad_token_id=self.pad_token_id, - decoder_start_token_id=self.pad_token_id, - **self.config_updates, ) - inputs_dict = prepare_xglm_inputs_dict(config, input_ids, decoder_input_ids) + + return (config, input_ids, input_mask) + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + config, input_ids, attention_mask = config_and_inputs + inputs_dict = {"input_ids": input_ids, "attention_mask": attention_mask} return config, inputs_dict - def check_use_cache_forward(self, model_class_name, config, inputs_dict): - max_decoder_length = 20 - model = model_class_name(config) + def prepare_config_and_inputs_for_decoder(self): + config, input_ids, attention_mask = self.prepare_config_and_inputs() - encoder_outputs = model.encode(inputs_dict["input_ids"]) + encoder_hidden_states = floats_tensor([self.batch_size, self.seq_length, self.hidden_size]) + encoder_attention_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2) - decoder_input_ids, decoder_attention_mask = ( - inputs_dict["decoder_input_ids"], - inputs_dict["decoder_attention_mask"], + return ( + config, + input_ids, + attention_mask, + encoder_hidden_states, + encoder_attention_mask, ) - past_key_values = model.init_cache(decoder_input_ids.shape[0], max_decoder_length, encoder_outputs) - decoder_attention_mask = jnp.ones((decoder_input_ids.shape[0], max_decoder_length), dtype="i4") + def check_use_cache_forward(self, model_class_name, config, input_ids, attention_mask): + max_decoder_length = 20 + model = model_class_name(config) - decoder_position_ids = jnp.broadcast_to( - jnp.arange(decoder_input_ids.shape[-1] - 1)[None, :], - (decoder_input_ids.shape[0], decoder_input_ids.shape[-1] - 1), + past_key_values = model.init_cache(input_ids.shape[0], max_decoder_length) + attention_mask = jnp.ones((input_ids.shape[0], max_decoder_length), dtype="i4") + + position_ids = jnp.broadcast_to( + jnp.arange(input_ids.shape[-1] - 1)[None, :], (input_ids.shape[0], input_ids.shape[-1] - 1) ) - outputs_cache = model.decode( - decoder_input_ids[:, :-1], - encoder_outputs, - decoder_attention_mask=decoder_attention_mask, + outputs_cache = model( + input_ids[:, :-1], + attention_mask=attention_mask, past_key_values=past_key_values, - decoder_position_ids=decoder_position_ids, + position_ids=position_ids, ) - decoder_position_ids = jnp.array(decoder_input_ids.shape[0] * [[decoder_input_ids.shape[-1] - 1]], dtype="i4") - outputs_cache_next = model.decode( - decoder_input_ids[:, -1:], - encoder_outputs, - decoder_attention_mask=decoder_attention_mask, + position_ids = jnp.array(input_ids.shape[0] * [[input_ids.shape[-1] - 1]], dtype="i4") + outputs_cache_next = model( + input_ids[:, -1:], + attention_mask=attention_mask, past_key_values=outputs_cache.past_key_values, - decoder_position_ids=decoder_position_ids, + position_ids=position_ids, ) - outputs = model.decode(decoder_input_ids, encoder_outputs) + outputs = model(input_ids) diff = np.max(np.abs((outputs_cache_next[0][:, -1, :5] - outputs[0][:, -1, :5]))) self.parent.assertTrue(diff < 1e-3, msg=f"Max diff is {diff}") - def check_use_cache_forward_with_attn_mask(self, model_class_name, config, inputs_dict): + def check_use_cache_forward_with_attn_mask(self, model_class_name, config, input_ids, attention_mask): max_decoder_length = 20 model = model_class_name(config) - encoder_outputs = model.encode(inputs_dict["input_ids"]) - - decoder_input_ids, decoder_attention_mask = ( - inputs_dict["decoder_input_ids"], - inputs_dict["decoder_attention_mask"], - ) - - decoder_attention_mask_cache = jnp.concatenate( - [ - decoder_attention_mask, - jnp.zeros((decoder_attention_mask.shape[0], max_decoder_length - decoder_attention_mask.shape[1])), - ], + attention_mask_cache = jnp.concatenate( + [attention_mask, jnp.zeros((attention_mask.shape[0], max_decoder_length - attention_mask.shape[1]))], axis=-1, ) - past_key_values = model.init_cache(decoder_input_ids.shape[0], max_decoder_length, encoder_outputs) - decoder_position_ids = jnp.broadcast_to( - jnp.arange(decoder_input_ids.shape[-1] - 1)[None, :], - (decoder_input_ids.shape[0], decoder_input_ids.shape[-1] - 1), + past_key_values = model.init_cache(input_ids.shape[0], max_decoder_length) + position_ids = jnp.broadcast_to( + jnp.arange(input_ids.shape[-1] - 1)[None, :], (input_ids.shape[0], input_ids.shape[-1] - 1) ) - outputs_cache = model.decode( - decoder_input_ids[:, :-1], - encoder_outputs, - decoder_attention_mask=decoder_attention_mask_cache, + outputs_cache = model( + input_ids[:, :-1], + attention_mask=attention_mask_cache, past_key_values=past_key_values, - decoder_position_ids=decoder_position_ids, + position_ids=position_ids, ) - decoder_position_ids = jnp.array(decoder_input_ids.shape[0] * [[decoder_input_ids.shape[-1] - 1]], dtype="i4") - outputs_cache_next = model.decode( - decoder_input_ids[:, -1:], - encoder_outputs, + position_ids = jnp.array(input_ids.shape[0] * [[input_ids.shape[-1] - 1]], dtype="i4") + outputs_cache_next = model( + input_ids[:, -1:], past_key_values=outputs_cache.past_key_values, - decoder_attention_mask=decoder_attention_mask_cache, - decoder_position_ids=decoder_position_ids, + attention_mask=attention_mask_cache, + position_ids=position_ids, ) - outputs = model.decode(decoder_input_ids, encoder_outputs, decoder_attention_mask=decoder_attention_mask) - + outputs = model(input_ids, attention_mask=attention_mask) diff = np.max(np.abs((outputs_cache_next[0][:, -1, :5] - outputs[0][:, -1, :5]))) self.parent.assertTrue(diff < 1e-3, msg=f"Max diff is {diff}") -def prepare_xglm_inputs_dict( - config, - input_ids, - decoder_input_ids, - attention_mask=None, - decoder_attention_mask=None, -): - if attention_mask is None: - attention_mask = np.not_equal(input_ids, config.pad_token_id).astype(np.int8) - if decoder_attention_mask is None: - decoder_attention_mask = np.concatenate( - [ - np.ones(decoder_input_ids[:, :1].shape, dtype=np.int8), - np.not_equal(decoder_input_ids[:, 1:], config.pad_token_id).astype(np.int8), - ], - axis=-1, - ) - return { - "input_ids": input_ids, - "decoder_input_ids": decoder_input_ids, - "attention_mask": attention_mask, - "decoder_attention_mask": decoder_attention_mask, - } - - @require_flax -class FlaxXGLMModelTest(FlaxModelTesterMixin, unittest.TestCase): - all_model_classes = ( - ( - FlaxXGLMForConditionalGeneration, - FlaxXGLMForQuestionAnswering, - FlaxXGLMForSequenceClassification, - FlaxXGLMModel, - ) - if is_flax_available() - else () - ) - all_generative_model_classes = (FlaxXGLMForConditionalGeneration,) if is_flax_available() else () - is_encoder_decoder = True - test_pruning = False - test_head_masking = False - test_onnx = False +class FlaxXGLMModelTest(FlaxModelTesterMixin, FlaxGenerationTesterMixin, unittest.TestCase): + + all_model_classes = (FlaxXGLMModel, FlaxXGLMForCausalLM) if is_flax_available() else () + all_generative_model_classes = (FlaxXGLMForCausalLM,) if is_flax_available() else () def setUp(self): self.model_tester = FlaxXGLMModelTester(self) - self.config_tester = ConfigTester(self, config_class=XGLMConfig) - - def test_config(self): - self.config_tester.run_common_tests() def test_use_cache_forward(self): - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - for model_class in self.all_model_classes: - self.model_tester.check_use_cache_forward(model_class, config, inputs_dict) + for model_class_name in self.all_model_classes: + config, input_ids, attention_mask = self.model_tester.prepare_config_and_inputs() + self.model_tester.check_use_cache_forward(model_class_name, config, input_ids, attention_mask) def test_use_cache_forward_with_attn_mask(self): - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - for model_class in self.all_model_classes: - self.model_tester.check_use_cache_forward_with_attn_mask(model_class, config, inputs_dict) + for model_class_name in self.all_model_classes: + config, input_ids, attention_mask = self.model_tester.prepare_config_and_inputs() + self.model_tester.check_use_cache_forward_with_attn_mask( + model_class_name, config, input_ids, attention_mask + ) + @slow + def test_batch_generation(self): + tokenizer = XGLMTokenizer.from_pretrained("XGLM", padding_side="left") + inputs = tokenizer(["Hello this is a long string", "Hey"], return_tensors="np", padding=True, truncation=True) -def _assert_tensors_equal(a, b, atol=1e-12, prefix=""): - """If tensors not close, or a and b arent both tensors, raise a nice Assertion error.""" - if a is None and b is None: - return True - try: - if _assert_tensors_equal(a, b, atol=atol): - return True - raise - except Exception: - if len(prefix) > 0: - prefix = f"{prefix}: " - raise AssertionError(f"{prefix}{a} != {b}") + model = FlaxXGLMForCausalLM.from_pretrained("facebook/xglm-564M") + model.config.num_beams = 1 + model.config.do_sample = False + model.config.pad_token_id = model.config.eos_token_id + jit_generate = jax.jit(model.generate) -def _long_tensor(tok_lst): - return np.array(tok_lst, dtype=np.int32) - - -TOLERANCE = 1e-4 - - -@slow -@require_sentencepiece -@require_tokenizers -@require_flax -class FlaxXGLMModelIntegrationTest(unittest.TestCase): - def test_inference_no_head(self): - model = FlaxXGLMModel.from_pretrained("facebook/xglm-564M") - # change to intended input here - input_ids = _long_tensor([[0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2]]) - decoder_input_ids = _long_tensor([[0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2]]) - inputs_dict = prepare_xglm_inputs_dict(model.config, input_ids, decoder_input_ids) - output = model(**inputs_dict)[0] - expected_shape = (1, 11, 1024) - self.assertEqual(output.shape, expected_shape) - # change to expected output here - expected_slice = np.array( - [[0.7144, 0.8143, -1.2813], [0.7144, 0.8143, -1.2813], [-0.0467, 2.5911, -2.1845]], - ) - _assert_tensors_equal(output[:, :3, :3], expected_slice, atol=TOLERANCE) - - def test_inference_with_head(self): - model = FlaxXGLMForConditionalGeneration.from_pretrained("facebook/xglm-564M") - # change to intended input here - input_ids = _long_tensor([[0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2]]) - decoder_input_ids = _long_tensor([[0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2]]) - inputs_dict = prepare_xglm_inputs_dict(model.config, input_ids, decoder_input_ids) - output = model(**inputs_dict)[0] - expected_shape = (1, 11, 1024) - self.assertEqual(output.shape, expected_shape) - # change to expected output here - expected_slice = np.array( - [[0.7144, 0.8143, -1.2813], [0.7144, 0.8143, -1.2813], [-0.0467, 2.5911, -2.1845]], - ) - _assert_tensors_equal(output[:, :3, :3], expected_slice, atol=TOLERANCE) + output_sequences = jit_generate(inputs["input_ids"], attention_mask=inputs["attention_mask"]).sequences - def test_seq_to_seq_generation(self): - hf = FlaxXGLMForConditionalGeneration.from_pretrained("facebook/xglm-564M") - tok = XGLMTokenizer.from_pretrained("facebook/xglm-564M") + output_string = tokenizer.batch_decode(output_sequences, skip_special_tokens=True) - batch_input = [ - # string 1, - # string 2, - # string 3, - # string 4, + expected_string = [ + "Hello this is a long string of words. I'm going to try to explain what I mean.", + "Hey, I'm not sure if I'm going to be able to do", ] - # The below article tests that we don't add any hypotheses outside of the top n_beams - dct = tok.batch_encode_plus( - batch_input, - max_length=512, - padding="max_length", - truncation_strategy="only_first", - truncation=True, - return_tensors="np", - ) - - hypotheses_batch = hf.generate( - input_ids=dct["input_ids"], - attention_mask=dct["attention_mask"], - num_beams=2, - ) + self.assertListEqual(output_string, expected_string) - EXPECTED = [ - # here expected 1, - # here expected 2, - # here expected 3, - # here expected 4, - ] + # overwrite from common since `attention_mask` in combination + # with `causal_mask` behaves slighly differently + # @is_pt_flax_cross_test + def test_equivalence_pt_to_flax(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - generated = tok.batch_decode( - hypotheses_batch.tolist(), clean_up_tokenization_spaces=True, skip_special_tokens=True - ) - assert generated == EXPECTED + for model_class in self.all_model_classes: + with self.subTest(model_class.__name__): + # prepare inputs + prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class) + pt_inputs = {k: torch.tensor(v.tolist()) for k, v in prepared_inputs_dict.items()} + + # load corresponding PyTorch class + pt_model_class_name = model_class.__name__[4:] # Skip the "Flax" at the beginning + pt_model_class = getattr(transformers, pt_model_class_name) + + batch_size, seq_length = pt_inputs["input_ids"].shape + rnd_start_indices = np.random.randint(0, seq_length - 1, size=(batch_size,)) + for batch_idx, start_index in enumerate(rnd_start_indices): + pt_inputs["attention_mask"][batch_idx, :start_index] = 0 + pt_inputs["attention_mask"][batch_idx, start_index:] = 1 + prepared_inputs_dict["attention_mask"][batch_idx, :start_index] = 0 + prepared_inputs_dict["attention_mask"][batch_idx, start_index:] = 1 + pt_model = pt_model_class(config).eval() + # Flax models don't use the `use_cache` option and cache is not returned as a default. + # So we disable `use_cache` here for PyTorch model. + pt_model.config.use_cache = False + fx_model = model_class(config, dtype=jnp.float32) + fx_state = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model) + fx_model.params = fx_state + + with torch.no_grad(): + pt_outputs = pt_model(**pt_inputs).to_tuple() + + fx_outputs = fx_model(**prepared_inputs_dict).to_tuple() + self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch") + for fx_output, pt_output in zip(fx_outputs, pt_outputs): + self.assert_almost_equals(fx_output[:, -1], pt_output[:, -1].numpy(), 4e-2) + + with tempfile.TemporaryDirectory() as tmpdirname: + pt_model.save_pretrained(tmpdirname) + fx_model_loaded = model_class.from_pretrained(tmpdirname, from_pt=True) + + fx_outputs_loaded = fx_model_loaded(**prepared_inputs_dict).to_tuple() + self.assertEqual( + len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch" + ) + for fx_output_loaded, pt_output in zip(fx_outputs_loaded, pt_outputs): + self.assert_almost_equals(fx_output_loaded[:, -1], pt_output[:, -1].numpy(), 4e-2) + + # overwrite from common since `attention_mask` in combination + # with `causal_mask` behaves slighly differently + # @is_pt_flax_cross_test + def test_equivalence_flax_to_pt(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + for model_class in self.all_model_classes: + with self.subTest(model_class.__name__): + # prepare inputs + prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class) + pt_inputs = {k: torch.tensor(v.tolist()) for k, v in prepared_inputs_dict.items()} + + # load corresponding PyTorch class + pt_model_class_name = model_class.__name__[4:] # Skip the "Flax" at the beginning + pt_model_class = getattr(transformers, pt_model_class_name) + + pt_model = pt_model_class(config).eval() + pt_model.config.use_cache = False + fx_model = model_class(config, dtype=jnp.float32) + + pt_model = load_flax_weights_in_pytorch_model(pt_model, fx_model.params) + batch_size, seq_length = pt_inputs["input_ids"].shape + rnd_start_indices = np.random.randint(0, seq_length - 1, size=(batch_size,)) + for batch_idx, start_index in enumerate(rnd_start_indices): + pt_inputs["attention_mask"][batch_idx, :start_index] = 0 + pt_inputs["attention_mask"][batch_idx, start_index:] = 1 + prepared_inputs_dict["attention_mask"][batch_idx, :start_index] = 0 + prepared_inputs_dict["attention_mask"][batch_idx, start_index:] = 1 + + # make sure weights are tied in PyTorch + pt_model.tie_weights() + + with torch.no_grad(): + pt_outputs = pt_model(**pt_inputs).to_tuple() + + fx_outputs = fx_model(**prepared_inputs_dict).to_tuple() + self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch") + for fx_output, pt_output in zip(fx_outputs, pt_outputs): + self.assert_almost_equals(fx_output[:, -1], pt_output[:, -1].numpy(), 4e-2) + + with tempfile.TemporaryDirectory() as tmpdirname: + fx_model.save_pretrained(tmpdirname) + pt_model_loaded = pt_model_class.from_pretrained(tmpdirname, from_flax=True) + + with torch.no_grad(): + pt_outputs_loaded = pt_model_loaded(**pt_inputs).to_tuple() + + self.assertEqual( + len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch" + ) + for fx_output, pt_output in zip(fx_outputs, pt_outputs_loaded): + self.assert_almost_equals(fx_output[:, -1], pt_output[:, -1].numpy(), 4e-2) + + @slow + def test_model_from_pretrained(self): + for model_class_name in self.all_model_classes: + model = model_class_name.from_pretrained("facebook/xglm-564M") + outputs = model(np.ones((1, 1))) + self.assertIsNotNone(outputs)