diff --git a/docs/source/en/index.mdx b/docs/source/en/index.mdx index 4ca9246941d49..8465f1909841f 100644 --- a/docs/source/en/index.mdx +++ b/docs/source/en/index.mdx @@ -243,7 +243,7 @@ Flax), PyTorch, and/or TensorFlow. | ELECTRA | ✅ | ✅ | ✅ | ✅ | ✅ | | Encoder decoder | ❌ | ❌ | ✅ | ✅ | ✅ | | ERNIE | ❌ | ❌ | ✅ | ❌ | ❌ | -| ESM | ✅ | ❌ | ✅ | ❌ | ❌ | +| ESM | ✅ | ❌ | ✅ | ✅ | ❌ | | FairSeq Machine-Translation | ✅ | ❌ | ✅ | ❌ | ❌ | | FlauBERT | ✅ | ❌ | ✅ | ✅ | ❌ | | FLAVA | ❌ | ❌ | ✅ | ❌ | ❌ | diff --git a/docs/source/en/model_doc/esm.mdx b/docs/source/en/model_doc/esm.mdx index d2fc949781bc2..fef638c8b4f35 100644 --- a/docs/source/en/model_doc/esm.mdx +++ b/docs/source/en/model_doc/esm.mdx @@ -107,3 +107,23 @@ and [Matt](https://huggingface.co/Rocketknight1). [[autodoc]] EsmForTokenClassification - forward + +## TFEsmModel + +[[autodoc]] TFEsmModel + - call + +## TFEsmForMaskedLM + +[[autodoc]] TFEsmForMaskedLM + - call + +## TFEsmForSequenceClassification + +[[autodoc]] TFEsmForSequenceClassification + - call + +## TFEsmForTokenClassification + +[[autodoc]] TFEsmForTokenClassification + - call diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 263a9a27cc22c..58a44a50f1d18 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -2462,6 +2462,16 @@ ] ) _import_structure["models.encoder_decoder"].append("TFEncoderDecoderModel") + _import_structure["models.esm"].extend( + [ + "ESM_PRETRAINED_MODEL_ARCHIVE_LIST", + "TFEsmForMaskedLM", + "TFEsmForSequenceClassification", + "TFEsmForTokenClassification", + "TFEsmModel", + "TFEsmPreTrainedModel", + ] + ) _import_structure["models.flaubert"].extend( [ "TF_FLAUBERT_PRETRAINED_MODEL_ARCHIVE_LIST", @@ -5134,6 +5144,14 @@ TFElectraPreTrainedModel, ) from .models.encoder_decoder import TFEncoderDecoderModel + from .models.esm import ( + ESM_PRETRAINED_MODEL_ARCHIVE_LIST, + TFEsmForMaskedLM, + TFEsmForSequenceClassification, + TFEsmForTokenClassification, + TFEsmModel, + TFEsmPreTrainedModel, + ) from .models.flaubert import ( TF_FLAUBERT_PRETRAINED_MODEL_ARCHIVE_LIST, TFFlaubertForMultipleChoice, diff --git a/src/transformers/models/auto/modeling_tf_auto.py b/src/transformers/models/auto/modeling_tf_auto.py index 95fe3f8aaaec0..8bb7b5595f355 100644 --- a/src/transformers/models/auto/modeling_tf_auto.py +++ b/src/transformers/models/auto/modeling_tf_auto.py @@ -47,6 +47,7 @@ ("distilbert", "TFDistilBertModel"), ("dpr", "TFDPRQuestionEncoder"), ("electra", "TFElectraModel"), + ("esm", "TFEsmModel"), ("flaubert", "TFFlaubertModel"), ("funnel", ("TFFunnelModel", "TFFunnelBaseModel")), ("gpt2", "TFGPT2Model"), @@ -129,6 +130,7 @@ ("ctrl", "TFCTRLLMHeadModel"), ("distilbert", "TFDistilBertForMaskedLM"), ("electra", "TFElectraForMaskedLM"), + ("esm", "TFEsmForMaskedLM"), ("flaubert", "TFFlaubertWithLMHeadModel"), ("funnel", "TFFunnelForMaskedLM"), ("gpt2", "TFGPT2LMHeadModel"), @@ -223,6 +225,7 @@ ("deberta-v2", "TFDebertaV2ForMaskedLM"), ("distilbert", "TFDistilBertForMaskedLM"), ("electra", "TFElectraForMaskedLM"), + ("esm", "TFEsmForMaskedLM"), ("flaubert", "TFFlaubertWithLMHeadModel"), ("funnel", "TFFunnelForMaskedLM"), ("layoutlm", "TFLayoutLMForMaskedLM"), @@ -273,6 +276,7 @@ ("deberta-v2", "TFDebertaV2ForSequenceClassification"), ("distilbert", "TFDistilBertForSequenceClassification"), ("electra", "TFElectraForSequenceClassification"), + ("esm", "TFEsmForSequenceClassification"), ("flaubert", "TFFlaubertForSequenceClassification"), ("funnel", "TFFunnelForSequenceClassification"), ("gpt2", 
"TFGPT2ForSequenceClassification"), @@ -346,6 +350,7 @@ ("deberta-v2", "TFDebertaV2ForTokenClassification"), ("distilbert", "TFDistilBertForTokenClassification"), ("electra", "TFElectraForTokenClassification"), + ("esm", "TFEsmForTokenClassification"), ("flaubert", "TFFlaubertForTokenClassification"), ("funnel", "TFFunnelForTokenClassification"), ("layoutlm", "TFLayoutLMForTokenClassification"), diff --git a/src/transformers/models/auto/tokenization_auto.py b/src/transformers/models/auto/tokenization_auto.py index e29a5b19ddc07..43aaa01153efe 100644 --- a/src/transformers/models/auto/tokenization_auto.py +++ b/src/transformers/models/auto/tokenization_auto.py @@ -122,6 +122,7 @@ ), ("electra", ("ElectraTokenizer", "ElectraTokenizerFast" if is_tokenizers_available() else None)), ("ernie", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)), + ("esm", ("EsmTokenizer", None)), ("flaubert", ("FlaubertTokenizer", None)), ("fnet", ("FNetTokenizer", "FNetTokenizerFast" if is_tokenizers_available() else None)), ("fsmt", ("FSMTTokenizer", None)), diff --git a/src/transformers/models/esm/__init__.py b/src/transformers/models/esm/__init__.py index 9d1d0687726e2..b394b7a0c52ed 100644 --- a/src/transformers/models/esm/__init__.py +++ b/src/transformers/models/esm/__init__.py @@ -17,7 +17,7 @@ # limitations under the License. from typing import TYPE_CHECKING -from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available +from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_tf_available, is_torch_available _import_structure = { @@ -40,6 +40,21 @@ "EsmPreTrainedModel", ] +try: + if not is_tf_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + pass +else: + _import_structure["modeling_tf_esm"] = [ + "TF_ESM_PRETRAINED_MODEL_ARCHIVE_LIST", + "TFEsmForMaskedLM", + "TFEsmForSequenceClassification", + "TFEsmForTokenClassification", + "TFEsmModel", + "TFEsmPreTrainedModel", + ] + if TYPE_CHECKING: from .configuration_esm import ESM_PRETRAINED_CONFIG_ARCHIVE_MAP, EsmConfig @@ -60,6 +75,21 @@ EsmPreTrainedModel, ) + try: + if not is_tf_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + pass + else: + from .modeling_tf_esm import ( + TF_ESM_PRETRAINED_MODEL_ARCHIVE_LIST, + TFEsmForMaskedLM, + TFEsmForSequenceClassification, + TFEsmForTokenClassification, + TFEsmModel, + TFEsmPreTrainedModel, + ) + else: import sys diff --git a/src/transformers/models/esm/modeling_esm.py b/src/transformers/models/esm/modeling_esm.py index 337f7e3716542..8a880d3b4b4ec 100755 --- a/src/transformers/models/esm/modeling_esm.py +++ b/src/transformers/models/esm/modeling_esm.py @@ -1,5 +1,5 @@ # coding=utf-8 -# Copyright 2022 Facebook and The HuggingFace Inc. team. All rights reserved. +# Copyright 2022 Meta and The HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -42,12 +42,14 @@ logger = logging.get_logger(__name__) -_CHECKPOINT_FOR_DOC = "facebook/esm-1b" +_CHECKPOINT_FOR_DOC = "Rocketknight1/esm2_t6_8M_UR50D" _CONFIG_FOR_DOC = "EsmConfig" _TOKENIZER_FOR_DOC = "EsmTokenizer" ESM_PRETRAINED_MODEL_ARCHIVE_LIST = [ - "facebook/esm-1b", + "Rocketknight1/esm2_t6_8M_UR50D", + "Rocketknight1/esm2_t12_35M_UR50D", + # This is not a complete list of all ESM models! 
# See all ESM models at https://huggingface.co/models?filter=esm ] @@ -115,7 +117,6 @@ class EsmEmbeddings(nn.Module): def __init__(self, config): super().__init__() self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id) - self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) if config.emb_layer_norm_before: self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) @@ -658,15 +659,6 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def update_keys_to_ignore(self, config, del_keys_to_ignore): - """Remove some keys from ignore list""" - if not config.tie_word_embeddings: - # must make a new list, or the class variable gets modified! - self._keys_to_ignore_on_save = [k for k in self._keys_to_ignore_on_save if k not in del_keys_to_ignore] - self._keys_to_ignore_on_load_missing = [ - k for k in self._keys_to_ignore_on_load_missing if k not in del_keys_to_ignore - ] - ESM_START_DOCSTRING = r""" @@ -907,8 +899,7 @@ def forward( @add_start_docstrings("""ESM Model with a `language modeling` head on top.""", ESM_START_DOCSTRING) class EsmForMaskedLM(EsmPreTrainedModel): - _keys_to_ignore_on_save = [r"lm_head.decoder.weight", r"lm_head.decoder.bias"] - _keys_to_ignore_on_load_missing = [r"position_ids", r"lm_head.decoder.weight", r"lm_head.decoder.bias"] + _keys_to_ignore_on_load_missing = [r"position_ids"] _keys_to_ignore_on_load_unexpected = [r"pooler"] def __init__(self, config): @@ -923,9 +914,6 @@ def __init__(self, config): self.esm = EsmModel(config, add_pooling_layer=False) self.lm_head = EsmLMHead(config) - # The LM head weights require special treatment only when they are tied with the word embeddings - self.update_keys_to_ignore(config, ["lm_head.decoder.weight"]) - self.init_weights() def get_output_embeddings(self): @@ -944,17 +932,17 @@ def set_output_embeddings(self, new_embeddings): ) def forward( self, - input_ids=None, - attention_mask=None, - position_ids=None, - head_mask=None, - inputs_embeds=None, - encoder_hidden_states=None, - encoder_attention_mask=None, - labels=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + labels: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, ): r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): @@ -1009,17 +997,13 @@ def __init__(self, config): self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False) self.bias = nn.Parameter(torch.zeros(config.vocab_size)) - # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` - self.decoder.bias = self.bias - def forward(self, features, **kwargs): x = self.dense(features) x = gelu(x) x = self.layer_norm(x) # project back to size of vocabulary with bias - x = self.decoder(x) - + x = self.decoder(x) + self.bias return x @@ -1052,15 +1036,15 @@ def __init__(self, config): ) def forward( self, - input_ids=None, - attention_mask=None, - position_ids=None, - 
head_mask=None, - inputs_embeds=None, - labels=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, ): r""" labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): @@ -1148,15 +1132,15 @@ def __init__(self, config): ) def forward( self, - input_ids=None, - attention_mask=None, - position_ids=None, - head_mask=None, - inputs_embeds=None, - labels=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, ): r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): diff --git a/src/transformers/models/esm/modeling_tf_esm.py b/src/transformers/models/esm/modeling_tf_esm.py new file mode 100644 index 0000000000000..929af109e4642 --- /dev/null +++ b/src/transformers/models/esm/modeling_tf_esm.py @@ -0,0 +1,1399 @@ +# coding=utf-8 +# Copyright 2022 Meta and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
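+# A minimal usage sketch of the TF classes defined in this file, assuming the
+# "Rocketknight1/esm2_t6_8M_UR50D" checkpoint referenced below also hosts TF weights
+# (otherwise `from_pt=True` would be needed in `from_pretrained`):
+#
+#     from transformers import EsmTokenizer, TFEsmForMaskedLM
+#
+#     tokenizer = EsmTokenizer.from_pretrained("Rocketknight1/esm2_t6_8M_UR50D")
+#     model = TFEsmForMaskedLM.from_pretrained("Rocketknight1/esm2_t6_8M_UR50D")
+#     inputs = tokenizer("MKTVRQERLK", return_tensors="tf")  # toy protein sequence
+#     outputs = model(**inputs)  # outputs.logits: (batch, sequence_length, config.vocab_size)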
+""" PyTorch ESM model.""" + +from typing import Optional, Tuple, Union + +import numpy as np +import tensorflow as tf +from tensorflow.keras.activations import gelu +from tensorflow.keras.layers import Dense, Dropout, Embedding, Layer, LayerNormalization + +from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward +from ...modeling_tf_outputs import ( + TFBaseModelOutputWithPastAndCrossAttentions, + TFBaseModelOutputWithPoolingAndCrossAttentions, + TFMaskedLMOutput, + TFSequenceClassifierOutput, + TFTokenClassifierOutput, +) +from ...modeling_tf_utils import ( + TFMaskedLanguageModelingLoss, + TFModelInputType, + TFPreTrainedModel, + TFSequenceClassificationLoss, + TFTokenClassificationLoss, + get_initializer, + get_tf_activation, + shape_list, + unpack_inputs, +) +from ...tf_utils import stable_softmax +from ...utils import logging +from .configuration_esm import EsmConfig + + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "Rocketknight1/esm2_t6_8M_UR50D" +_CONFIG_FOR_DOC = "EsmConfig" +_TOKENIZER_FOR_DOC = "EsmTokenizer" + +TF_ESM_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "Rocketknight1/esm2_t6_8M_UR50D", + "Rocketknight1/esm2_t12_35M_UR50D", + # This is not a complete list of all ESM models! + # See all ESM models at https://huggingface.co/models?filter=esm +] + + +def rotate_half(x): + x1, x2 = tf.split(x, 2, axis=-1) + return tf.concat((-x2, x1), axis=-1) + + +def apply_rotary_pos_emb(x, cos, sin): + cos = cos[:, :, : tf.shape(x)[-2], :] + sin = sin[:, :, : tf.shape(x)[-2], :] + + return (x * cos) + (rotate_half(x) * sin) + + +class TFRotaryEmbedding(Layer): + """ + Rotary position embeddings based on those in + [RoFormer](https://huggingface.co/docs/transformers/model_doc/roformer). Query and keys are transformed by rotation + matrices which depend on their relative positions. + """ + + def __init__(self, dim: int, name=None): + super().__init__(name=name) + # Matt: The PyTorch version of this layer does a lot of work to cache values, but we just rely on TF compilation + # and/or XLA to sort out constants like that. It actually may not seem like this layer needs to be stateful at + # all when we benefit from TF compilation, but it does. The reason is that self.inv_freq is a buffer in the + # original implementation, but all the shared ESM checkpoints were trained with fp16 params. This means that + # the inv_freq tensor was stored as a float16, and we need to replicate those lower-precision values or our + # models give different outputs from the original. 
+ self.dim = dim + + def build(self, input_shape): + super().build(input_shape) + self.inv_freq = self.add_weight( + "inv_freq", shape=(self.dim // 2,), dtype=tf.float32, initializer=get_initializer(1.0) + ) + self.inv_freq.assign( + 1.0 / (10000 ** (tf.range(start=0, limit=self.dim, delta=2, dtype=tf.float32) / self.dim)) + ) + + def _compute_cos_sin(self, x, seq_dimension=2): + seq_len = tf.shape(x)[seq_dimension] + + t = tf.range(seq_len, dtype=self.inv_freq.dtype) + freqs = tf.einsum("i, j -> ij", t, self.inv_freq) # Outer multiplication + emb = tf.concat((freqs, freqs), axis=-1)[None, None, :, :] + + return tf.cos(emb), tf.sin(emb) + + def call(self, q: tf.Tensor, k: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]: + cos_emb, sin_emb = self._compute_cos_sin(k, seq_dimension=-2) + + return ( + apply_rotary_pos_emb(q, cos_emb, sin_emb), + apply_rotary_pos_emb(k, cos_emb, sin_emb), + ) + + +class TFEsmEmbeddings(Layer): + """ + Same as BertEmbeddings with a tiny tweak for positional embeddings indexing. + """ + + def __init__(self, config, name=None): + super().__init__(name=name) + self.word_embeddings = Embedding( + config.vocab_size, + config.hidden_size, + embeddings_initializer=get_initializer(config.initializer_range), + name="word_embeddings", + ) + self.position_embeddings = Embedding( + config.max_position_embeddings, + config.hidden_size, + embeddings_initializer=get_initializer(config.initializer_range), + name="position_embeddings", + ) + + if config.emb_layer_norm_before: + self.layer_norm = LayerNormalization(epsilon=config.layer_norm_eps) + else: + self.layer_norm = None + # Matt: I think this line was copied incorrectly from BERT, disabling for now + # self.dropout = Dropout(config.hidden_dropout_prob) + self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") + + self.position_ids = tf.range(config.max_position_embeddings)[None, :] + + self.padding_idx = config.pad_token_id + self.token_dropout = config.token_dropout + self.mask_token_id = config.mask_token_id + self.vocab_size = config.vocab_size + + def call( + self, input_ids=None, attention_mask=None, position_ids=None, inputs_embeds=None, past_key_values_length=0 + ): + if position_ids is None: + if input_ids is not None: + # Create the position ids from the input token ids. Any padded tokens remain padded. + position_ids = create_position_ids_from_input_ids(input_ids, self.padding_idx, past_key_values_length) + else: + position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds) + + if inputs_embeds is None: + # Note: tf.gather, on which the embedding layer is based, won't check positive out of bound + # indices on GPU, returning zeros instead. This is a dangerous silent behavior. + tf.debugging.assert_less( + input_ids, + tf.cast(self.vocab_size, dtype=input_ids.dtype), + message=( + "input_ids must be smaller than the embedding layer's input dimension (got" + f" {tf.math.reduce_max(input_ids)} >= {self.vocab_size})" + ), + ) + inputs_embeds = self.word_embeddings(input_ids) + + # Note that if we want to support ESM-1 (not 1b!) in future then we need to support an + # embedding_scale factor here. + embeddings = inputs_embeds + + # Matt: ESM has the option to handle masking in MLM in a slightly unusual way. If the token_dropout + # flag is False then it is handled in the same was as BERT/RoBERTa. If it is set to True, however, + # masked tokens are treated as if they were selected for input dropout and zeroed out. 
+ # This "mask-dropout" is compensated for when masked tokens are not present, by scaling embeddings by + # a factor of (fraction of unmasked tokens during training) / (fraction of unmasked tokens in sample). + # This is analogous to the way that dropout layers scale down outputs during evaluation when not + # actually dropping out values (or, equivalently, scale up their un-dropped outputs in training). + if self.token_dropout: + embeddings = tf.where((input_ids == self.mask_token_id)[:, :, None], 0.0, embeddings) + mask_ratio_train = 0.15 * 0.8 # Hardcoded as the ratio used in all ESM model training runs + src_lengths = tf.cast(tf.reduce_sum(attention_mask, axis=-1), tf.float32) + masked_tokens = input_ids == self.mask_token_id + mask_ratio_observed = tf.math.count_nonzero(masked_tokens, dtype=tf.float32, axis=-1) / src_lengths + embeddings = embeddings * (1 - mask_ratio_train) / (1 - mask_ratio_observed)[:, None, None] + + if self.position_embedding_type == "absolute": + position_embeddings = self.position_embeddings(position_ids) + embeddings += position_embeddings + + if self.layer_norm is not None: + embeddings = self.layer_norm(embeddings) + if attention_mask is not None: + embeddings = embeddings * tf.cast(tf.expand_dims(attention_mask, -1), embeddings.dtype) + # Matt: I think this line was copied incorrectly from BERT, disabling it for now. + # embeddings = self.dropout(embeddings) + return embeddings + + def create_position_ids_from_inputs_embeds(self, inputs_embeds): + """ + We are provided embeddings directly. We cannot infer which are padded so just generate sequential position ids. + + Args: + inputs_embeds: tf.Tensor + + Returns: tf.Tensor + """ + input_shape = shape_list(inputs_embeds)[:-1] + sequence_length = input_shape[1] + + position_ids = tf.range( + start=self.padding_idx + 1, limit=sequence_length + self.padding_idx + 1, dtype=tf.int64 + ) + return tf.broadcast_to(tf.expand_dims(position_ids, 0), input_shape) + + +class TFEsmSelfAttention(Layer): + def __init__(self, config, position_embedding_type=None, name=None): + super().__init__(name=name) + if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): + raise ValueError( + f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention " + f"heads ({config.num_attention_heads})" + ) + + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = Dense( + self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="query" + ) + self.key = Dense(self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="key") + self.value = Dense( + self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="value" + ) + + self.dropout = Dropout(config.attention_probs_dropout_prob) + self.position_embedding_type = position_embedding_type or getattr( + config, "position_embedding_type", "absolute" + ) + self.rotary_embeddings = None + if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": + self.max_position_embeddings = config.max_position_embeddings + self.distance_embedding = Embedding( + 2 * config.max_position_embeddings - 1, + self.attention_head_size, + embeddings_initializer=get_initializer(config.initializer_range), + ) + elif self.position_embedding_type == 
"rotary": + self.rotary_embeddings = TFRotaryEmbedding(dim=self.attention_head_size, name="rotary_embeddings") + + self.is_decoder = config.is_decoder + + def transpose_for_scores(self, x: tf.Tensor) -> tf.Tensor: + new_x_shape = shape_list(x)[:-1] + [self.num_attention_heads, self.attention_head_size] + x = tf.reshape(x, new_x_shape) + return tf.transpose(x, perm=(0, 2, 1, 3)) + + def call( + self, + hidden_states: tf.Tensor, + attention_mask: Optional[tf.Tensor] = None, + head_mask: Optional[tf.Tensor] = None, + encoder_hidden_states: Optional[tf.Tensor] = None, + encoder_attention_mask: Optional[tf.Tensor] = None, + past_key_value: Optional[Tuple[Tuple[tf.Tensor]]] = None, + output_attentions: Optional[bool] = False, + training: bool = False, + ) -> Tuple[tf.Tensor]: + + mixed_query_layer = self.query(hidden_states) + + # If this is instantiated as a cross-attention module, the keys + # and values come from an encoder; the attention mask needs to be + # such that the encoder's padding tokens are not attended to. + is_cross_attention = encoder_hidden_states is not None + + if is_cross_attention and past_key_value is not None: + # reuse k,v, cross_attentions + key_layer = past_key_value[0] + value_layer = past_key_value[1] + attention_mask = encoder_attention_mask + elif is_cross_attention: + key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) + value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) + attention_mask = encoder_attention_mask + elif past_key_value is not None: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + key_layer = tf.concat([past_key_value[0], key_layer], axis=2) + value_layer = tf.concat([past_key_value[1], value_layer], axis=2) + else: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + + query_layer = self.transpose_for_scores(mixed_query_layer) + + # Matt: Our BERT model (which this code was derived from) scales attention logits down by sqrt(head_dim). + # ESM scales the query down by the same factor instead. Modulo numerical stability these are equivalent, + # but not when rotary embeddings get involved. Therefore, we scale the query here to match the original + # ESM code and fix rotary embeddings. + query_layer = query_layer * self.attention_head_size**-0.5 + + if self.is_decoder: + # if cross_attention save Tuple(tf.Tensor, tf.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(tf.Tensor, tf.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_layer, value_layer) + + if self.position_embedding_type == "rotary": + query_layer, key_layer = self.rotary_embeddings(query_layer, key_layer) + + # Take the dot product between "query" and "key" to get the raw attention scores. 
+ attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True) + + if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": + seq_length = shape_list(hidden_states)[1] + position_ids_l = tf.expand_dims(tf.range(seq_length, dtype=tf.int64), -1) + position_ids_r = tf.expand_dims(tf.range(seq_length, dtype=tf.int64), 0) + distance = position_ids_l - position_ids_r + positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1) + positional_embedding = tf.cast(positional_embedding, query_layer.dtype) # fp16 compatibility + + if self.position_embedding_type == "relative_key": + relative_position_scores = tf.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) + attention_scores = attention_scores + relative_position_scores + elif self.position_embedding_type == "relative_key_query": + relative_position_scores_query = tf.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) + relative_position_scores_key = tf.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding) + attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key + + if attention_mask is not None: + # Apply the attention mask is (precomputed for all layers in EsmModel forward() function) + attention_scores = attention_scores + attention_mask + + # Normalize the attention scores to probabilities. + attention_probs = stable_softmax(attention_scores, axis=-1) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs = self.dropout(attention_probs, training=training) + + # Mask heads if we want to + if head_mask is not None: + attention_probs = attention_probs * head_mask + + context_layer = attention_probs @ value_layer + + context_layer = tf.transpose(context_layer, perm=(0, 2, 1, 3)) + new_context_layer_shape = shape_list(context_layer)[:-2] + [self.all_head_size] + context_layer = tf.reshape(context_layer, new_context_layer_shape) + + outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) + + if self.is_decoder: + outputs = outputs + (past_key_value,) + return outputs + + +class TFEsmSelfOutput(Layer): + def __init__(self, config, name=None): + super().__init__(name=name) + self.dense = Dense( + config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense" + ) + self.dropout = Dropout(config.hidden_dropout_prob) + + def call(self, hidden_states, input_tensor, training=False): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states, training=training) + hidden_states += input_tensor + return hidden_states + + +class TFEsmAttention(Layer): + def __init__(self, config, name=None): + super().__init__(name=name) + self.self = TFEsmSelfAttention(config, name="self") + self.output_layer = TFEsmSelfOutput(config, name="output") + self.pruned_heads = set() + self.LayerNorm = LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm") + + def prune_heads(self, heads): + raise NotImplementedError + + def call( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + training=False, + ): + hidden_states_ln = self.LayerNorm(hidden_states) + self_outputs = self.self( + hidden_states_ln, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + 
past_key_value, + output_attentions, + training, + ) + attention_output = self.output_layer(self_outputs[0], hidden_states) + outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them + return outputs + + +# Copied from transformers.models.bert.modeling_tf_bert.TFBertIntermediate with Bert->Esm +class TFEsmIntermediate(tf.keras.layers.Layer): + def __init__(self, config: EsmConfig, **kwargs): + super().__init__(**kwargs) + + self.dense = tf.keras.layers.Dense( + units=config.intermediate_size, kernel_initializer=get_initializer(config.initializer_range), name="dense" + ) + + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = get_tf_activation(config.hidden_act) + else: + self.intermediate_act_fn = config.hidden_act + + def call(self, hidden_states: tf.Tensor) -> tf.Tensor: + hidden_states = self.dense(inputs=hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + + return hidden_states + + +class TFEsmOutput(Layer): + def __init__(self, config, name=None): + super().__init__(name=name) + self.dense = Dense( + config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense" + ) + self.dropout = Dropout(config.hidden_dropout_prob) + + def call(self, hidden_states, input_tensor, training=False): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states, training=training) + hidden_states += input_tensor + return hidden_states + + +class TFEsmLayer(Layer): + def __init__(self, config, name=None): + super().__init__(name=name) + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.seq_len_dim = 1 + self.attention = TFEsmAttention(config, name="attention") + self.is_decoder = config.is_decoder + self.add_cross_attention = config.add_cross_attention + if self.add_cross_attention: + if not self.is_decoder: + raise RuntimeError(f"{self} should be used as a decoder model if cross attention is added") + self.crossattention = TFEsmAttention(config) + self.intermediate = TFEsmIntermediate(config, name="intermediate") + self.output_layer = TFEsmOutput(config, name="output") + self.LayerNorm = LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm") + + def call( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + training=False, + ): + # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 + self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None + self_attention_outputs = self.attention( + hidden_states, + attention_mask, + head_mask, + output_attentions=output_attentions, + past_key_value=self_attn_past_key_value, + training=training, + ) + attention_output = self_attention_outputs[0] + + # if decoder, the last output is tuple of self-attn cache + if self.is_decoder: + outputs = self_attention_outputs[1:-1] + present_key_value = self_attention_outputs[-1] + else: + outputs = self_attention_outputs[1:] # add self attentions if we output attention weights + + cross_attn_present_key_value = None + if self.is_decoder and encoder_hidden_states is not None: + if not hasattr(self, "crossattention"): + raise AttributeError( + f"If `encoder_hidden_states` are passed, {self} has to be instantiated" + " with cross-attention layers by setting `config.add_cross_attention=True`" + ) + + # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple + 
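+            # (i.e. when both self- and cross-attention caches are present, past_key_value is laid out as
+            # (self_attn_key, self_attn_value, cross_attn_key, cross_attn_value))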
cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None + cross_attention_outputs = self.crossattention( + attention_output, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + cross_attn_past_key_value, + output_attentions, + training=training, + ) + attention_output = cross_attention_outputs[0] + outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights + + # add cross-attn cache to positions 3,4 of present_key_value tuple + cross_attn_present_key_value = cross_attention_outputs[-1] + present_key_value = present_key_value + cross_attn_present_key_value + + layernorm_output = self.LayerNorm(attention_output) + intermediate_output = self.intermediate(hidden_states=layernorm_output) + layer_output = self.output_layer( + hidden_states=intermediate_output, input_tensor=attention_output, training=training + ) + outputs = (layer_output,) + outputs # add attentions if we output them + + # if decoder, return the attn key/values as the last output + if self.is_decoder: + outputs = outputs + (present_key_value,) + + return outputs + + +class TFEsmEncoder(Layer): + def __init__(self, config, name=None): + super().__init__(name=name) + self.config = config + self.layer = [TFEsmLayer(config, name=f"layer_._{i}") for i in range(config.num_hidden_layers)] + self.emb_layer_norm_after = LayerNormalization(epsilon=config.layer_norm_eps, name="emb_layer_norm_after") + + def call( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=False, + output_hidden_states=False, + return_dict=True, + training=False, + ): + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None + + next_decoder_cache = () if use_cache else None + for i, layer_module in enumerate(self.layer): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_head_mask = head_mask[i] if head_mask is not None else None + past_key_value = past_key_values[i] if past_key_values is not None else None + + layer_outputs = layer_module( + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + training, + ) + + hidden_states = layer_outputs[0] + if use_cache: + next_decoder_cache += (layer_outputs[-1],) + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + if self.config.add_cross_attention: + all_cross_attentions = all_cross_attentions + (layer_outputs[2],) + + if self.emb_layer_norm_after: + hidden_states = self.emb_layer_norm_after(hidden_states) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [ + hidden_states, + next_decoder_cache, + all_hidden_states, + all_self_attentions, + all_cross_attentions, + ] + if v is not None + ) + return TFBaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=next_decoder_cache, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, + ) + + +# Copied from transformers.models.bert.modeling_tf_bert.TFBertPooler with Bert->Esm +class TFEsmPooler(Layer): + def __init__(self, 
config: EsmConfig, **kwargs): + super().__init__(**kwargs) + + self.dense = tf.keras.layers.Dense( + units=config.hidden_size, + kernel_initializer=get_initializer(config.initializer_range), + activation="tanh", + name="dense", + ) + + def call(self, hidden_states: tf.Tensor) -> tf.Tensor: + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. + first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(inputs=first_token_tensor) + + return pooled_output + + +class TFEsmPreTrainedModel(TFPreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = EsmConfig + base_model_prefix = "esm" + + +ESM_START_DOCSTRING = r""" + + This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a Keras [Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it as a + regular Keras model and refer to the TF/Keras documentation for all matters related to general usage and behavior. + + Parameters: + config ([`EsmConfig`]): Model configuration class with all the parameters of the + model. Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~TFPreTrainedModel.from_pretrained`] method to load the model weights. +""" + +ESM_INPUTS_DOCSTRING = r""" + Args: + input_ids (`tf.Tensor` of shape `({0})`): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using [`EsmTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`tf.Tensor` of shape `({0})`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + position_ids (`tf.Tensor` of shape `({0})`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. + + [What are position IDs?](../glossary#position-ids) + head_mask (`tf.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + inputs_embeds (`tf.Tensor` of shape `({0}, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple. 
+""" + + +@add_start_docstrings( + "The bare ESM Model transformer outputting raw hidden-states without any specific head on top.", + ESM_START_DOCSTRING, +) +class TFEsmMainLayer(Layer): + """ + + The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of + cross-attention is added between the self-attention layers, following the architecture described in [Attention is + all you need](https://arxiv.org/abs/1706.03762) by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, + Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin. + + To behave as an decoder the model needs to be initialized with the `is_decoder` argument of the configuration set + to `True`. To be used in a Seq2Seq model, the model needs to initialized with both `is_decoder` argument and + `add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass. + """ + + _keys_to_ignore_on_load_missing = [r"position_ids"] + + def __init__(self, config, add_pooling_layer=True, name=None, **kwargs): + super().__init__(name=name, **kwargs) + + self.config = config + self.is_decoder = config.is_decoder + + self.embeddings = TFEsmEmbeddings(config, name="embeddings") + self.encoder = TFEsmEncoder(config, name="encoder") + self.pooler = TFEsmPooler(config, name="pooler") if add_pooling_layer else None + + def get_input_embeddings(self): + return self.embeddings.word_embeddings + + def set_input_embeddings(self, value: tf.Variable): + self.embeddings.word_embeddings.weight = value + self.embeddings.vocab_size = shape_list(value)[0] + + def _prune_heads(self, heads_to_prune): + raise NotImplementedError + + def call( + self, + input_ids: Optional[TFModelInputType] = None, + attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None, + position_ids: Optional[Union[np.ndarray, tf.Tensor]] = None, + head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None, + inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None, + encoder_hidden_states: Optional[Union[np.ndarray, tf.Tensor]] = None, + encoder_attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None, + past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + training: bool = False, + ) -> Union[TFBaseModelOutputWithPoolingAndCrossAttentions, Tuple[tf.Tensor]]: + + if not self.config.is_decoder: + use_cache = False + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input_shape = shape_list(input_ids) + elif inputs_embeds is not None: + input_shape = shape_list(inputs_embeds)[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + batch_size, seq_length = input_shape + + if past_key_values is None: + past_key_values_length = 0 + past_key_values = [None] * len(self.encoder.layer) + else: + past_key_values_length = shape_list(past_key_values[0][0])[-2] + + if attention_mask is None: + attention_mask = tf.fill(dims=(batch_size, seq_length + past_key_values_length), value=1) + + embedding_output = self.embeddings( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + past_key_values_length=past_key_values_length, + training=training, + ) + + # We create a 
3D attention mask from a 2D tensor mask. + # Sizes are [batch_size, 1, 1, to_seq_length] + # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] + # this attention mask is more simple than the triangular masking of causal attention + # used in OpenAI GPT, we just need to prepare the broadcast dimension here. + attention_mask_shape = shape_list(attention_mask) + + mask_seq_length = seq_length + past_key_values_length + # Copied from `modeling_tf_t5.py` + # Provided a padding mask of dimensions [batch_size, mask_seq_length] + # - if the model is a decoder, apply a causal mask in addition to the padding mask + # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, mask_seq_length, mask_seq_length] + if self.is_decoder: + seq_ids = tf.range(mask_seq_length) + causal_mask = tf.less_equal( + tf.tile(seq_ids[None, None, :], (batch_size, mask_seq_length, 1)), + seq_ids[None, :, None], + ) + causal_mask = tf.cast(causal_mask, dtype=attention_mask.dtype) + extended_attention_mask = causal_mask * attention_mask[:, None, :] + attention_mask_shape = shape_list(extended_attention_mask) + extended_attention_mask = tf.reshape( + extended_attention_mask, (attention_mask_shape[0], 1, attention_mask_shape[1], attention_mask_shape[2]) + ) + if past_key_values[0] is not None: + # attention_mask needs to be sliced to the shape `[batch_size, 1, from_seq_length - cached_seq_length, to_seq_length] + extended_attention_mask = extended_attention_mask[:, :, -seq_length:, :] + else: + extended_attention_mask = tf.reshape( + attention_mask, (attention_mask_shape[0], 1, 1, attention_mask_shape[1]) + ) + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and -10000.0 for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. + extended_attention_mask = tf.cast(extended_attention_mask, dtype=embedding_output.dtype) + one_cst = tf.constant(1.0, dtype=embedding_output.dtype) + ten_thousand_cst = tf.constant(-10000.0, dtype=embedding_output.dtype) + extended_attention_mask = tf.multiply(tf.subtract(one_cst, extended_attention_mask), ten_thousand_cst) + + # Copied from `modeling_tf_t5.py` with -1e9 -> -10000 + if self.is_decoder and encoder_attention_mask is not None: + # If a 2D ou 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, mask_seq_length, mask_seq_length] + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + encoder_attention_mask = tf.cast(encoder_attention_mask, dtype=extended_attention_mask.dtype) + num_dims_encoder_attention_mask = len(shape_list(encoder_attention_mask)) + if num_dims_encoder_attention_mask == 3: + encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :] + if num_dims_encoder_attention_mask == 2: + encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :] + + # T5 has a mask that can compare sequence ids, we can simulate this here with this transposition + # Cf. 
https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow/transformer/transformer_layers.py#L270 + # encoder_extended_attention_mask = tf.math.equal(encoder_extended_attention_mask, + # tf.transpose(encoder_extended_attention_mask, perm=(-1, -2))) + + encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -10000.0 + else: + encoder_extended_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + if head_mask is not None: + raise NotImplementedError + else: + head_mask = [None] * self.config.num_hidden_layers + + encoder_outputs = self.encoder( + hidden_states=embedding_output, + attention_mask=extended_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + + sequence_output = encoder_outputs[0] + pooled_output = self.pooler(hidden_states=sequence_output) if self.pooler is not None else None + + if not return_dict: + return ( + sequence_output, + pooled_output, + ) + encoder_outputs[1:] + + return TFBaseModelOutputWithPoolingAndCrossAttentions( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + past_key_values=encoder_outputs.past_key_values, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + cross_attentions=encoder_outputs.cross_attentions, + ) + + +@add_start_docstrings( + "The bare ESM Model transformer outputting raw hidden-states without any specific head on top.", + ESM_START_DOCSTRING, +) +class TFEsmModel(TFEsmPreTrainedModel): + def __init__(self, config: EsmConfig, add_pooling_layer=True, *inputs, **kwargs): + super().__init__(config, *inputs, **kwargs) + + self.esm = TFEsmMainLayer(config, add_pooling_layer=add_pooling_layer, name="esm") + + @unpack_inputs + @add_start_docstrings_to_model_forward(ESM_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + processor_class=_TOKENIZER_FOR_DOC, + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TFBaseModelOutputWithPoolingAndCrossAttentions, + config_class=_CONFIG_FOR_DOC, + ) + def call( + self, + input_ids: Optional[TFModelInputType] = None, + attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None, + position_ids: Optional[Union[np.ndarray, tf.Tensor]] = None, + head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None, + inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None, + encoder_hidden_states: Optional[Union[np.ndarray, tf.Tensor]] = None, + encoder_attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None, + past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + training: Optional[bool] = False, + ) -> Union[TFBaseModelOutputWithPoolingAndCrossAttentions, Tuple[tf.Tensor]]: + r""" + encoder_hidden_states (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at 
the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + past_key_values (`Tuple[Tuple[tf.Tensor]]` of length `config.n_layers`) + contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + use_cache (`bool`, *optional*, defaults to `True`): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). Set to `False` during training, `True` during generation + """ + outputs = self.esm( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + return outputs + + @tf.function( + input_signature=[ + { + "input_ids": tf.TensorSpec((None, None), tf.int64, name="input_ids"), + "attention_mask": tf.TensorSpec((None, None), tf.int64, name="attention_mask"), + } + ] + ) + def serving(self, inputs): + output = self.call(inputs) + + return self.serving_output(output) + + def serving_output( + self, output: TFBaseModelOutputWithPoolingAndCrossAttentions + ) -> TFBaseModelOutputWithPoolingAndCrossAttentions: + output_cache = self.config.use_cache and self.config.is_decoder + pkv = tf.convert_to_tensor(output.past_key_values) if output_cache else None + hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None + attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None + cross_attns = tf.convert_to_tensor(output.cross_attentions) if output.cross_attentions is not None else None + if not (self.config.output_attentions and self.config.add_cross_attention): + cross_attns = None + + return TFBaseModelOutputWithPoolingAndCrossAttentions( + last_hidden_state=output.last_hidden_state, + pooler_output=output.pooler_output, + past_key_values=pkv, + hidden_states=hs, + attentions=attns, + cross_attentions=cross_attns, + ) + + +@add_start_docstrings("""ESM Model with a `language modeling` head on top.""", ESM_START_DOCSTRING) +class TFEsmForMaskedLM(TFEsmPreTrainedModel, TFMaskedLanguageModelingLoss): + _keys_to_ignore_on_load_missing = [r"position_ids"] + _keys_to_ignore_on_load_unexpected = [r"pooler"] + + def __init__(self, config): + super().__init__(config) + + if config.is_decoder: + logger.warning( + "If you want to use `EsmForMaskedLM` make sure `config.is_decoder=False` for " + "bi-directional self-attention." 
+ ) + + self.esm = TFEsmMainLayer(config, add_pooling_layer=False, name="esm") + self.lm_head = TFEsmLMHead(config, name="lm_head") + + def get_output_embeddings(self): + return self.lm_head.decoder + + def set_output_embeddings(self, new_embeddings): + self.lm_head.decoder = new_embeddings + + def get_lm_head(self): + return self.lm_head + + @unpack_inputs + @add_start_docstrings_to_model_forward(ESM_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + processor_class=_TOKENIZER_FOR_DOC, + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TFMaskedLMOutput, + config_class=_CONFIG_FOR_DOC, + mask="", + ) + def call( + self, + input_ids: Optional[TFModelInputType] = None, + attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None, + position_ids: Optional[Union[np.ndarray, tf.Tensor]] = None, + head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None, + inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None, + encoder_hidden_states: Optional[Union[np.ndarray, tf.Tensor]] = None, + encoder_attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None, + labels: Optional[Union[np.ndarray, tf.Tensor]] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + training: bool = False, + ) -> Union[TFMaskedLMOutput, Tuple[tf.Tensor]]: + r""" + labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ..., + config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the + loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]` + kwargs (`Dict[str, any]`, optional, defaults to *{}*): + Used to hide legacy arguments that have been deprecated. 
+ """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.esm( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + sequence_output = outputs[0] + prediction_scores = self.lm_head(sequence_output) + + masked_lm_loss = None + if labels is not None: + masked_lm_loss = self.hf_compute_loss(labels=labels, logits=prediction_scores) + + if not return_dict: + output = (prediction_scores,) + outputs[2:] + return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output + + return TFMaskedLMOutput( + loss=masked_lm_loss, + logits=prediction_scores, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + # Copied from transformers.models.bert.modeling_tf_bert.TFBertForMaskedLM.serving_output + def serving_output(self, output: TFMaskedLMOutput) -> TFMaskedLMOutput: + hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None + attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None + + return TFMaskedLMOutput(logits=output.logits, hidden_states=hs, attentions=attns) + + @tf.function( + input_signature=[ + { + "input_ids": tf.TensorSpec((None, None), tf.int64, name="input_ids"), + "attention_mask": tf.TensorSpec((None, None), tf.int64, name="attention_mask"), + } + ] + ) + def serving(self, inputs): + output = self.call(inputs) + + return self.serving_output(output) + + +class TFEsmLMHead(Layer): + """ESM Head for masked language modeling.""" + + def __init__(self, config, name=None): + super().__init__(name=name) + self.dense = Dense( + config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense" + ) + + self.layer_norm = LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm") + + self.decoder = Dense( + config.vocab_size, + use_bias=False, + kernel_initializer=get_initializer(config.initializer_range), + name="decoder", + ) + self.vocab_size = config.vocab_size + + def build(self, input_shape): + super().build(input_shape) + # Separate bias to match the PT model and allow weight cross-loading to work + # Put it in the build so it gets the right name when adding it as a weight + self.bias = self.add_weight("bias", shape=(self.vocab_size,), initializer="zeros", trainable=True) + + def get_bias(self): + return {"bias": self.bias} + + def call(self, features): + x = self.dense(features) + x = gelu(x) + x = self.layer_norm(x) + + # project back to size of vocabulary with bias + x = self.decoder(x) + x = x + self.bias + return x + + +@add_start_docstrings( + """ + ESM Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled + output) e.g. for GLUE tasks. 
+ """, + ESM_START_DOCSTRING, +) +class TFEsmForSequenceClassification(TFEsmPreTrainedModel, TFSequenceClassificationLoss): + _keys_to_ignore_on_load_missing = [r"position_ids"] + + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.config = config + + self.esm = TFEsmMainLayer(config, add_pooling_layer=False, name="esm") + self.classifier = TFEsmClassificationHead(config, name="classifier") + + @unpack_inputs + @add_start_docstrings_to_model_forward(ESM_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + processor_class=_TOKENIZER_FOR_DOC, + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TFSequenceClassifierOutput, + config_class=_CONFIG_FOR_DOC, + ) + def call( + self, + input_ids: Optional[TFModelInputType] = None, + attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None, + position_ids: Optional[Union[np.ndarray, tf.Tensor]] = None, + head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None, + inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None, + labels: Optional[Union[np.ndarray, tf.Tensor]] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + training: bool = False, + ) -> Union[TFSequenceClassifierOutput, Tuple[tf.Tensor]]: + r""" + labels (`tf.Tensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.esm( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + sequence_output = outputs[0] + logits = self.classifier(sequence_output) + + loss = None if labels is None else self.hf_compute_loss(labels, logits) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TFSequenceClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + # Copied from transformers.models.bert.modeling_tf_bert.TFBertForSequenceClassification.serving_output + def serving_output(self, output: TFSequenceClassifierOutput) -> TFSequenceClassifierOutput: + hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None + attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None + + return TFSequenceClassifierOutput(logits=output.logits, hidden_states=hs, attentions=attns) + + @tf.function( + input_signature=[ + { + "input_ids": tf.TensorSpec((None, None, None), tf.int64, name="input_ids"), + "attention_mask": tf.TensorSpec((None, None, None), tf.int64, name="attention_mask"), + } + ] + ) + def serving(self, inputs): + output = self.call(inputs) + + return self.serving_output(output) + + +@add_start_docstrings( + """ + ESM Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for + Named-Entity-Recognition (NER) tasks. 
+ """, + ESM_START_DOCSTRING, +) +class TFEsmForTokenClassification(TFEsmPreTrainedModel, TFTokenClassificationLoss): + _keys_to_ignore_on_load_unexpected = [r"pooler"] + _keys_to_ignore_on_load_missing = [r"position_ids"] + + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + + self.esm = TFEsmMainLayer(config, add_pooling_layer=False, name="esm") + self.dropout = Dropout(config.hidden_dropout_prob) + self.classifier = Dense(config.num_labels, name="classifier") + + @unpack_inputs + @add_start_docstrings_to_model_forward(ESM_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + processor_class=_TOKENIZER_FOR_DOC, + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TFTokenClassifierOutput, + config_class=_CONFIG_FOR_DOC, + ) + def call( + self, + input_ids: Optional[TFModelInputType] = None, + attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None, + position_ids: Optional[Union[np.ndarray, tf.Tensor]] = None, + head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None, + inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None, + labels: Optional[Union[np.ndarray, tf.Tensor]] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + training: bool = False, + ) -> Union[TFTokenClassifierOutput, Tuple[tf.Tensor]]: + r""" + labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.esm( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + + sequence_output = outputs[0] + + sequence_output = self.dropout(sequence_output, training=training) + logits = self.classifier(sequence_output) + + loss = None if labels is None else self.hf_compute_loss(labels, logits) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TFTokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + # Copied from transformers.models.bert.modeling_tf_bert.TFBertForTokenClassification.serving_output + def serving_output(self, output: TFTokenClassifierOutput) -> TFTokenClassifierOutput: + hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None + attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None + + return TFTokenClassifierOutput(logits=output.logits, hidden_states=hs, attentions=attns) + + @tf.function( + input_signature=[ + { + "input_ids": tf.TensorSpec((None, None), tf.int64, name="input_ids"), + "attention_mask": tf.TensorSpec((None, None), tf.int64, name="attention_mask"), + } + ] + ) + def serving(self, inputs): + output = self.call(inputs) + + return self.serving_output(output) + + +class TFEsmClassificationHead(Layer): + """Head for sentence-level classification tasks.""" + + def __init__(self, config, name=None): + super().__init__(name=name) + self.dense = Dense( + config.hidden_size, + kernel_initializer=get_initializer(config.initializer_range), + activation="tanh", + name="dense", 
+ ) + self.dropout = Dropout(config.hidden_dropout_prob) + self.out_proj = Dense( + config.num_labels, + kernel_initializer=get_initializer(config.initializer_range), + activation="linear", + name="out_proj", + ) + + def call(self, features, training=False): + x = features[:, 0, :] # take token (equiv. to [CLS]) + x = self.dropout(x, training=training) + x = self.dense(x) + x = self.dropout(x, training=training) + x = self.out_proj(x) + return x + + +def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0): + """ + Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols + are ignored. This is modified from fairseq's `utils.make_positions`. + + Args: + x: tf.Tensor x: + + Returns: tf.Tensor + """ + # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA. + mask = tf.cast(input_ids != padding_idx, tf.int64) + incremental_indices = (tf.cumsum(mask, axis=1) + past_key_values_length) * mask + return incremental_indices + padding_idx diff --git a/src/transformers/models/esm/tokenization_esm.py b/src/transformers/models/esm/tokenization_esm.py index 0512ccf864578..e83d11ddb42cc 100644 --- a/src/transformers/models/esm/tokenization_esm.py +++ b/src/transformers/models/esm/tokenization_esm.py @@ -1,5 +1,5 @@ # coding=utf-8 -# Copyright Facebook and The HuggingFace Inc. team. All rights reserved. +# Copyright 2022 Meta and The HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -27,12 +27,18 @@ PRETRAINED_VOCAB_FILES_MAP = { "vocab_file": { - "facebook/esm1b": "https://huggingface.co/facebook/esm1b/resolve/main/vocab.txt", + "Rocketknight1/esm2_t6_8M_UR50D": ( + "https://huggingface.co/Rocketknight1/esm2_t6_8M_UR50D/resolve/main/vocab.txt" + ), + "Rocketknight1/esm2_t12_35M_UR50D": ( + "https://huggingface.co/Rocketknight1/esm2_t12_35M_UR50D/resolve/main/vocab.txt" + ), }, } PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { - "facebook/esm1b": 1024, + "Rocketknight1/esm2_t6_8M_UR50D": 1024, + "Rocketknight1/esm2_t12_35M_UR50D": 1024, } @@ -77,6 +83,9 @@ def _tokenize(self, text, **kwargs): def get_vocab_size(self, with_added_tokens=False): return len(self._id_to_token) + def get_vocab(self): + return {token: i for i, token in enumerate(self.all_tokens)} + def token_to_id(self, token: str) -> int: return self._token_to_id.get(token, self._token_to_id.get(self.unk_token)) @@ -86,11 +95,42 @@ def id_to_token(self, index: int) -> str: def build_inputs_with_special_tokens( self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None ) -> List[int]: + if token_ids_1 is None: + return [self.cls_token_id] + token_ids_0 + [self.eos_token_id] + cls = [self.cls_token_id] + sep = [self.eos_token_id] # No sep token in ESM vocabulary + return cls + token_ids_0 + sep + token_ids_1 + sep + + def get_special_tokens_mask( + self, token_ids_0: List, token_ids_1: Optional[List] = None, already_has_special_tokens: bool = False + ) -> List[int]: + """ + Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding + special tokens using the tokenizer `prepare_for_model` or `encode_plus` methods. + + Args: + token_ids_0 (`List[int]`): + List of ids of the first sequence. + token_ids_1 (`List[int]`, *optional*): + List of ids of the second sequence. 
+ already_has_special_tokens (`bool`, *optional*, defaults to `False`): + Whether or not the token list is already formatted with special tokens for the model. + + Returns: + A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token. + """ + if already_has_special_tokens: + if token_ids_1 is not None: + raise ValueError( + "You should not supply a second sequence if the provided sequence of " + "ids is already formatted with special tokens for the model." + ) + + return [1 if token in self.all_special_ids else 0 for token in token_ids_0] + mask = [1] + ([0] * len(token_ids_0)) + [1] if token_ids_1 is not None: - raise ValueError("Multiple input sentences are not supported!") - cls_: List[int] = [self.cls_token_id] - eos_: List[int] = [self.eos_token_id] - return cls_ + token_ids_0 + eos_ + mask += [0] * len(token_ids_1) + [1] + return mask def save_vocabulary(self, save_directory, filename_prefix): vocab_file = os.path.join(save_directory, (filename_prefix + "-" if filename_prefix else "") + "vocab.txt") diff --git a/src/transformers/pipelines/fill_mask.py b/src/transformers/pipelines/fill_mask.py index f461f6faa2af6..c4f547d9c286a 100644 --- a/src/transformers/pipelines/fill_mask.py +++ b/src/transformers/pipelines/fill_mask.py @@ -138,7 +138,7 @@ def postprocess(self, model_outputs, top_k=5, target_ids=None): # For multi masks though, the other [MASK] would be removed otherwise # making the output look odd, so we add them back sequence = self.tokenizer.decode(tokens, skip_special_tokens=single_mask) - proposition = {"score": v, "token": p, "token_str": self.tokenizer.decode(p), "sequence": sequence} + proposition = {"score": v, "token": p, "token_str": self.tokenizer.decode([p]), "sequence": sequence} row.append(proposition) result.append(row) if single_mask: diff --git a/src/transformers/pipelines/pt_utils.py b/src/transformers/pipelines/pt_utils.py index 8eb7ac7798269..455fdc34f7a07 100644 --- a/src/transformers/pipelines/pt_utils.py +++ b/src/transformers/pipelines/pt_utils.py @@ -83,7 +83,10 @@ def loader_batch_item(self): elif isinstance(element[0], np.ndarray): loader_batched[k] = tuple(np.expand_dims(el[self._loader_batch_index], 0) for el in element) continue - if isinstance(element[self._loader_batch_index], torch.Tensor): + if element is None: + # This can happen for optional data that get passed around + loader_batched[k] = None + elif isinstance(element[self._loader_batch_index], torch.Tensor): # Take correct batch data, but make it looked like batch_size=1 # For compatibility with other methods within transformers diff --git a/src/transformers/utils/dummy_tf_objects.py b/src/transformers/utils/dummy_tf_objects.py index dbfe14c7f633c..f72d2e2a9c35e 100644 --- a/src/transformers/utils/dummy_tf_objects.py +++ b/src/transformers/utils/dummy_tf_objects.py @@ -1142,6 +1142,44 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["tf"]) +ESM_PRETRAINED_MODEL_ARCHIVE_LIST = None + + +class TFEsmForMaskedLM(metaclass=DummyObject): + _backends = ["tf"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["tf"]) + + +class TFEsmForSequenceClassification(metaclass=DummyObject): + _backends = ["tf"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["tf"]) + + +class TFEsmForTokenClassification(metaclass=DummyObject): + _backends = ["tf"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["tf"]) + + +class TFEsmModel(metaclass=DummyObject): + _backends = ["tf"] + + def __init__(self, *args, **kwargs): 
+ requires_backends(self, ["tf"]) + + +class TFEsmPreTrainedModel(metaclass=DummyObject): + _backends = ["tf"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["tf"]) + + TF_FLAUBERT_PRETRAINED_MODEL_ARCHIVE_LIST = None diff --git a/tests/models/esm/test_modeling_esm.py b/tests/models/esm/test_modeling_esm.py index dce9cb69e804a..9733223e39c73 100644 --- a/tests/models/esm/test_modeling_esm.py +++ b/tests/models/esm/test_modeling_esm.py @@ -240,6 +240,14 @@ def test_create_position_ids_from_inputs_embeds(self): self.assertEqual(position_ids.shape, expected_positions.shape) self.assertTrue(torch.all(torch.eq(position_ids, expected_positions))) + @unittest.skip("Esm does not support embedding resizing") + def test_resize_embeddings_untied(self): + pass + + @unittest.skip("Esm does not support embedding resizing") + def test_resize_tokens_embeddings(self): + pass + @require_torch class EsmModelIntegrationTest(TestCasePlus): @@ -270,24 +278,3 @@ def test_inference_no_head(self): [[[0.1444, 0.5413, 0.3248], [0.3034, 0.0053, 0.3108], [0.3228, -0.2499, 0.3415]]] ) self.assertTrue(torch.allclose(output[:, :3, :3], expected_slice, atol=1e-4)) - - def test_lm_head_ignore_keys(self): - from copy import deepcopy - - keys_to_ignore_on_save_tied = [r"lm_head.decoder.weight", r"lm_head.decoder.bias"] - keys_to_ignore_on_save_untied = [r"lm_head.decoder.bias"] - config = EsmConfig.from_pretrained("Rocketknight1/esm2_t6_8M_UR50D") - config_tied = deepcopy(config) - config_tied.tie_word_embeddings = True - config_untied = deepcopy(config) - config_untied.tie_word_embeddings = False - for cls in [EsmForMaskedLM]: - model = cls(config_tied) - self.assertEqual(model._keys_to_ignore_on_save, keys_to_ignore_on_save_tied, cls) - - # the keys should be different when embeddings aren't tied - model = cls(config_untied) - self.assertEqual(model._keys_to_ignore_on_save, keys_to_ignore_on_save_untied, cls) - - # test that saving works with updated ignore keys - just testing that it doesn't fail - model.save_pretrained(self.get_auto_remove_tmp_dir()) diff --git a/tests/models/esm/test_modeling_tf_esm.py b/tests/models/esm/test_modeling_tf_esm.py new file mode 100644 index 0000000000000..9a9aa8b1ac8ea --- /dev/null +++ b/tests/models/esm/test_modeling_tf_esm.py @@ -0,0 +1,287 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +import unittest + +from transformers import EsmConfig, is_tf_available +from transformers.testing_utils import require_tf, slow + +from ...test_configuration_common import ConfigTester +from ...test_modeling_tf_common import TFModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask + + +if is_tf_available(): + import numpy + import tensorflow as tf + + from transformers.models.esm.modeling_tf_esm import ( + TF_ESM_PRETRAINED_MODEL_ARCHIVE_LIST, + TFEsmForMaskedLM, + TFEsmForSequenceClassification, + TFEsmForTokenClassification, + TFEsmModel, + ) + + +# copied from tests.test_modeling_tf_roberta +class TFEsmModelTester: + def __init__( + self, + parent, + ): + self.parent = parent + self.batch_size = 13 + self.seq_length = 7 + self.is_training = True + self.use_input_mask = True + self.use_labels = True + self.vocab_size = 99 + self.hidden_size = 32 + self.num_hidden_layers = 5 + self.num_attention_heads = 4 + self.intermediate_size = 37 + self.hidden_act = "gelu" + self.hidden_dropout_prob = 0.1 + self.attention_probs_dropout_prob = 0.1 + self.max_position_embeddings = 512 + self.type_vocab_size = 16 + self.type_sequence_label_size = 2 + self.initializer_range = 0.02 + self.num_labels = 3 + self.num_choices = 4 + self.scope = None + + def prepare_config_and_inputs(self): + input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) + + input_mask = None + if self.use_input_mask: + input_mask = random_attention_mask([self.batch_size, self.seq_length]) + + sequence_labels = None + token_labels = None + choice_labels = None + if self.use_labels: + sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size) + token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels) + choice_labels = ids_tensor([self.batch_size], self.num_choices) + + config = EsmConfig( + vocab_size=self.vocab_size, + hidden_size=self.hidden_size, + num_hidden_layers=self.num_hidden_layers, + pad_token_id=1, + num_attention_heads=self.num_attention_heads, + intermediate_size=self.intermediate_size, + hidden_act=self.hidden_act, + hidden_dropout_prob=self.hidden_dropout_prob, + attention_probs_dropout_prob=self.attention_probs_dropout_prob, + max_position_embeddings=self.max_position_embeddings, + type_vocab_size=self.type_vocab_size, + initializer_range=self.initializer_range, + ) + + return config, input_ids, input_mask, sequence_labels, token_labels, choice_labels + + def prepare_config_and_inputs_for_decoder(self): + ( + config, + input_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + ) = self.prepare_config_and_inputs() + + config.is_decoder = True + encoder_hidden_states = floats_tensor([self.batch_size, self.seq_length, self.hidden_size]) + encoder_attention_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2) + + return ( + config, + input_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + encoder_hidden_states, + encoder_attention_mask, + ) + + def create_and_check_model(self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels): + model = TFEsmModel(config=config) + inputs = {"input_ids": input_ids, "attention_mask": input_mask} + result = model(inputs) + + inputs = [input_ids, input_mask] + result = model(inputs) + + result = model(input_ids) + + self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size)) + + def create_and_check_model_as_decoder( + self, + config, + input_ids, + input_mask, + sequence_labels, + 
token_labels, + choice_labels, + encoder_hidden_states, + encoder_attention_mask, + ): + config.add_cross_attention = True + + model = TFEsmModel(config=config) + inputs = { + "input_ids": input_ids, + "attention_mask": input_mask, + "encoder_hidden_states": encoder_hidden_states, + "encoder_attention_mask": encoder_attention_mask, + } + result = model(inputs) + + inputs = [input_ids, input_mask] + result = model(inputs, encoder_hidden_states=encoder_hidden_states) + + # Also check the case where encoder outputs are not passed + result = model(input_ids, attention_mask=input_mask) + + self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size)) + + def create_and_check_for_masked_lm( + self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels + ): + model = TFEsmForMaskedLM(config=config) + result = model([input_ids, input_mask]) + self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size)) + + def create_and_check_for_token_classification( + self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels + ): + config.num_labels = self.num_labels + model = TFEsmForTokenClassification(config=config) + inputs = {"input_ids": input_ids, "attention_mask": input_mask} + result = model(inputs) + self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.num_labels)) + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + ( + config, + input_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + ) = config_and_inputs + inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask} + return config, inputs_dict + + +@require_tf +class TFEsmModelTest(TFModelTesterMixin, unittest.TestCase): + + all_model_classes = ( + ( + TFEsmModel, + TFEsmForMaskedLM, + TFEsmForSequenceClassification, + TFEsmForTokenClassification, + ) + if is_tf_available() + else () + ) + test_head_masking = False + test_onnx = False + + def setUp(self): + self.model_tester = TFEsmModelTester(self) + self.config_tester = ConfigTester(self, config_class=EsmConfig, hidden_size=37) + + def test_config(self): + self.config_tester.run_common_tests() + + def test_model(self): + """Test the base model""" + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_model(*config_and_inputs) + + def test_model_as_decoder(self): + """Test the base model as a decoder (of an encoder-decoder architecture) + + is_deocder=True + cross_attention + pass encoder outputs + """ + config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder() + self.model_tester.create_and_check_model_as_decoder(*config_and_inputs) + + def test_for_masked_lm(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_for_masked_lm(*config_and_inputs) + + def test_for_token_classification(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_for_token_classification(*config_and_inputs) + + @slow + def test_model_from_pretrained(self): + for model_name in TF_ESM_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: + model = TFEsmModel.from_pretrained(model_name) + self.assertIsNotNone(model) + + @unittest.skip("Protein models do not support embedding resizing.") + def test_resize_token_embeddings(self): + pass + + @unittest.skip("Protein models do not support embedding resizing.") + def 
test_save_load_after_resize_token_embeddings(self): + pass + + +@require_tf +class TFEsmModelIntegrationTest(unittest.TestCase): + @slow + def test_inference_masked_lm(self): + model = TFEsmForMaskedLM.from_pretrained("Rocketknight1/esm2_t6_8M_UR50D") + + input_ids = tf.constant([[0, 1, 2, 3, 4, 5]]) + output = model(input_ids)[0] + expected_shape = [1, 6, 33] + self.assertEqual(list(output.numpy().shape), expected_shape) + # compare the actual values for a slice. + expected_slice = tf.constant( + [[[15.0963, -6.6414, -1.1346], [-0.2209, -9.9633, 4.2082], [-1.6045, -10.0011, 1.5882]]] + ) + self.assertTrue(numpy.allclose(output[:, :3, :3].numpy(), expected_slice.numpy(), atol=1e-4)) + + @slow + def test_inference_no_head(self): + model = TFEsmModel.from_pretrained("Rocketknight1/esm2_t6_8M_UR50D") + + input_ids = tf.constant([[0, 6, 4, 13, 5, 4, 16, 12, 11, 7, 2]]) + output = model(input_ids)[0] + # compare the actual values for a slice. + expected_slice = tf.constant( + [ + [ + [0.144337, 0.541198, 0.32479298], + [0.30328932, 0.00519154, 0.31089523], + [0.32273883, -0.24992886, 0.34143737], + ] + ] + ) + self.assertTrue(numpy.allclose(output[:, :3, :3].numpy(), expected_slice.numpy(), atol=1e-4)) diff --git a/tests/pipelines/test_pipelines_token_classification.py b/tests/pipelines/test_pipelines_token_classification.py index ff86e7106ae5b..2e44448e1336a 100644 --- a/tests/pipelines/test_pipelines_token_classification.py +++ b/tests/pipelines/test_pipelines_token_classification.py @@ -44,6 +44,8 @@ def get_test_pipeline(self, model, tokenizer, feature_extractor): def run_pipeline_test(self, token_classifier, _): model = token_classifier.model tokenizer = token_classifier.tokenizer + if not tokenizer.is_fast: + return # Slow tokenizers do not return offsets mappings, so this test will fail outputs = token_classifier("A simple string") self.assertIsInstance(outputs, list)
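As a quick sanity check of the new TF classes, the masked-LM path can be exercised much like `TFEsmModelIntegrationTest.test_inference_masked_lm` above. This is a minimal sketch, not part of the patch; the checkpoint name and the expected `(1, 6, 33)` shape are taken from the integration test, while the toy input ids are arbitrary.

```python
import tensorflow as tf

from transformers import TFEsmForMaskedLM

# Checkpoint name used throughout this diff's tests.
model = TFEsmForMaskedLM.from_pretrained("Rocketknight1/esm2_t6_8M_UR50D")

# Same toy input ids as test_inference_masked_lm.
input_ids = tf.constant([[0, 1, 2, 3, 4, 5]])
logits = model(input_ids).logits

print(logits.shape)  # (1, 6, 33): batch size, sequence length, vocab size
```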
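The sequence- and token-classification heads follow the same calling convention. Below is a sketch for per-residue classification with a hypothetical three-label task (not something defined in this diff); the classifier head is freshly initialized by `from_pretrained`, so the logits are only meaningful after fine-tuning.

```python
import tensorflow as tf

from transformers import TFEsmForTokenClassification

# Hypothetical 3-label per-residue task; the classification head is randomly initialized.
model = TFEsmForTokenClassification.from_pretrained(
    "Rocketknight1/esm2_t6_8M_UR50D", num_labels=3
)

input_ids = tf.constant([[0, 6, 4, 13, 5, 2]])  # arbitrary toy token ids
logits = model(input_ids).logits

print(logits.shape)  # (1, 6, 3): one score per label for every token
```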
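The tokenizer changes (`build_inputs_with_special_tokens` and the new `get_special_tokens_mask`) can be illustrated directly: a single sequence becomes `<cls> … <eos>`, and a pair becomes `<cls> A <eos> B <eos>`, since the ESM vocabulary has no dedicated separator token. The sketch below assumes the diff's checkpoint and a few single-letter amino-acid tokens from the ESM vocabulary.

```python
from transformers import EsmTokenizer

tokenizer = EsmTokenizer.from_pretrained("Rocketknight1/esm2_t6_8M_UR50D")

ids_a = tokenizer.convert_tokens_to_ids(["M", "K", "T"])
ids_b = tokenizer.convert_tokens_to_ids(["G", "A", "V"])

# Single sequence: <cls> + tokens + <eos>
print(tokenizer.build_inputs_with_special_tokens(ids_a))

# Pair: <cls> + A + <eos> + B + <eos>  (<eos> doubles as the separator)
print(tokenizer.build_inputs_with_special_tokens(ids_a, ids_b))

# 1 marks the special tokens added above, 0 marks sequence tokens
print(tokenizer.get_special_tokens_mask(ids_a, ids_b))  # [1, 0, 0, 0, 1, 0, 0, 0, 1]
```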
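Finally, a worked example of `create_position_ids_from_input_ids` from modeling_tf_esm.py, restated here only to show its effect on padded inputs: real tokens get positions starting at `padding_idx + 1`, while padding positions stay at `padding_idx`. The `padding_idx=1` choice mirrors the `pad_token_id=1` used in the TF test config; the input ids are arbitrary.

```python
import tensorflow as tf


# Restated from the diff for illustration.
def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0):
    mask = tf.cast(input_ids != padding_idx, tf.int64)
    incremental_indices = (tf.cumsum(mask, axis=1) + past_key_values_length) * mask
    return incremental_indices + padding_idx


input_ids = tf.constant([[5, 6, 7, 1, 1]], dtype=tf.int64)
print(create_position_ids_from_input_ids(input_ids, padding_idx=1).numpy())
# [[2 3 4 1 1]]
```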