
TF port of ESM (#19587)
* Partial TF port for ESM model

* Add ESM-TF tests

* Add the various imports for TF-ESM

* TF weight conversion almost ready

* Stop ignoring the decoder weights in PT

* Add tests and lots of fixes

* fix-copies

* Fix imports, add model docs

* Add get_vocab() to tokenizer

* Fix vocab links for pretrained files

* Allow multiple inputs with a sep

* Use EOS as SEP token because ESM vocab lacks SEP

* Correctly return special tokens mask from ESM tokenizer

* make fixup

* Stop testing unsupported embedding resizing

* Handle TF bias correctly

* Skip all models with slow tokenizers in the token classification test

* Fixing the batch/unbatcher of pipelines to accommodate the `None` being passed around.

* Fixing a pipeline bug caused by the slow tokenizer being different.

* Update src/transformers/models/esm/modeling_tf_esm.py

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* Update src/transformers/models/esm/modeling_tf_esm.py

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* Update src/transformers/models/esm/modeling_tf_esm.py

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* Update set_input_embeddings and the copyright notices

Co-authored-by: Your Name <you@example.com>
Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
4 people committed Oct 17, 2022
1 parent d7754c4 commit 3b3024d
Showing 15 changed files with 1,898 additions and 84 deletions.
2 changes: 1 addition & 1 deletion docs/source/en/index.mdx
@@ -243,7 +243,7 @@ Flax), PyTorch, and/or TensorFlow.
| ELECTRA | | | | | |
| Encoder decoder | | | | | |
| ERNIE | | | | | |
| ESM | ✅ | ❌ | ✅ | ❌ | ❌ |
| ESM | ✅ | ❌ | ✅ | ✅ | ❌ |
| FairSeq Machine-Translation | | | | | |
| FlauBERT | | | | | |
| FLAVA | | | | | |
20 changes: 20 additions & 0 deletions docs/source/en/model_doc/esm.mdx
@@ -107,3 +107,23 @@ and [Matt](https://huggingface.co/Rocketknight1).

[[autodoc]] EsmForTokenClassification
- forward

## TFEsmModel

[[autodoc]] TFEsmModel
- call

## TFEsmForMaskedLM

[[autodoc]] TFEsmForMaskedLM
- call

## TFEsmForSequenceClassification

[[autodoc]] TFEsmForSequenceClassification
- call

## TFEsmForTokenClassification

[[autodoc]] TFEsmForTokenClassification
- call
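
The four TF classes documented above follow the standard TensorFlow model API in `transformers`. A minimal usage sketch, assuming TensorFlow is installed and using one of the small ESM-2 checkpoints referenced elsewhere in this diff (`Rocketknight1/esm2_t6_8M_UR50D`); the protein sequence is an arbitrary example input:

```python
from transformers import AutoTokenizer, TFEsmForMaskedLM

# Small ESM-2 checkpoint named in this PR; pass from_pt=True if the hub repo only ships PyTorch weights.
checkpoint = "Rocketknight1/esm2_t6_8M_UR50D"

tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = TFEsmForMaskedLM.from_pretrained(checkpoint)

# ESM operates on protein sequences written as amino-acid characters.
inputs = tokenizer("MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGY", return_tensors="tf")
outputs = model(**inputs)

print(outputs.logits.shape)  # (batch_size, sequence_length, vocab_size)
```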
18 changes: 18 additions & 0 deletions src/transformers/__init__.py
@@ -2462,6 +2462,16 @@
]
)
_import_structure["models.encoder_decoder"].append("TFEncoderDecoderModel")
_import_structure["models.esm"].extend(
[
"ESM_PRETRAINED_MODEL_ARCHIVE_LIST",
"TFEsmForMaskedLM",
"TFEsmForSequenceClassification",
"TFEsmForTokenClassification",
"TFEsmModel",
"TFEsmPreTrainedModel",
]
)
_import_structure["models.flaubert"].extend(
[
"TF_FLAUBERT_PRETRAINED_MODEL_ARCHIVE_LIST",
@@ -5134,6 +5144,14 @@
TFElectraPreTrainedModel,
)
from .models.encoder_decoder import TFEncoderDecoderModel
from .models.esm import (
ESM_PRETRAINED_MODEL_ARCHIVE_LIST,
TFEsmForMaskedLM,
TFEsmForSequenceClassification,
TFEsmForTokenClassification,
TFEsmModel,
TFEsmPreTrainedModel,
)
from .models.flaubert import (
TF_FLAUBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
TFFlaubertForMultipleChoice,
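
With the import-structure entries above in place, the new classes become importable directly from the top-level package. A quick check, assuming TensorFlow is available in the environment:

```python
from transformers import TFEsmModel, TFEsmForMaskedLM

print(TFEsmModel.__module__)  # transformers.models.esm.modeling_tf_esm
```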
5 changes: 5 additions & 0 deletions src/transformers/models/auto/modeling_tf_auto.py
@@ -47,6 +47,7 @@
("distilbert", "TFDistilBertModel"),
("dpr", "TFDPRQuestionEncoder"),
("electra", "TFElectraModel"),
("esm", "TFEsmModel"),
("flaubert", "TFFlaubertModel"),
("funnel", ("TFFunnelModel", "TFFunnelBaseModel")),
("gpt2", "TFGPT2Model"),
@@ -129,6 +130,7 @@
("ctrl", "TFCTRLLMHeadModel"),
("distilbert", "TFDistilBertForMaskedLM"),
("electra", "TFElectraForMaskedLM"),
("esm", "TFEsmForMaskedLM"),
("flaubert", "TFFlaubertWithLMHeadModel"),
("funnel", "TFFunnelForMaskedLM"),
("gpt2", "TFGPT2LMHeadModel"),
@@ -223,6 +225,7 @@
("deberta-v2", "TFDebertaV2ForMaskedLM"),
("distilbert", "TFDistilBertForMaskedLM"),
("electra", "TFElectraForMaskedLM"),
("esm", "TFEsmForMaskedLM"),
("flaubert", "TFFlaubertWithLMHeadModel"),
("funnel", "TFFunnelForMaskedLM"),
("layoutlm", "TFLayoutLMForMaskedLM"),
@@ -273,6 +276,7 @@
("deberta-v2", "TFDebertaV2ForSequenceClassification"),
("distilbert", "TFDistilBertForSequenceClassification"),
("electra", "TFElectraForSequenceClassification"),
("esm", "TFEsmForSequenceClassification"),
("flaubert", "TFFlaubertForSequenceClassification"),
("funnel", "TFFunnelForSequenceClassification"),
("gpt2", "TFGPT2ForSequenceClassification"),
@@ -346,6 +350,7 @@
("deberta-v2", "TFDebertaV2ForTokenClassification"),
("distilbert", "TFDistilBertForTokenClassification"),
("electra", "TFElectraForTokenClassification"),
("esm", "TFEsmForTokenClassification"),
("flaubert", "TFFlaubertForTokenClassification"),
("funnel", "TFFunnelForTokenClassification"),
("layoutlm", "TFLayoutLMForTokenClassification"),
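
These mappings let the TF auto classes resolve the `esm` model type. A small sketch, again using a checkpoint name taken from this PR (add `from_pt=True` if the hosted repo only contains PyTorch weights):

```python
from transformers import TFAutoModelForMaskedLM

model = TFAutoModelForMaskedLM.from_pretrained("Rocketknight1/esm2_t6_8M_UR50D")
print(type(model).__name__)  # TFEsmForMaskedLM
```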
1 change: 1 addition & 0 deletions src/transformers/models/auto/tokenization_auto.py
@@ -122,6 +122,7 @@
),
("electra", ("ElectraTokenizer", "ElectraTokenizerFast" if is_tokenizers_available() else None)),
("ernie", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
("esm", ("EsmTokenizer", None)),
("flaubert", ("FlaubertTokenizer", None)),
("fnet", ("FNetTokenizer", "FNetTokenizerFast" if is_tokenizers_available() else None)),
("fsmt", ("FSMTTokenizer", None)),
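
`EsmTokenizer` is registered with no fast variant. Per the commit notes, it now exposes `get_vocab()` and accepts sequence pairs, reusing `<eos>` as the separator because the ESM vocabulary has no dedicated SEP token. A minimal sketch, with shortened example protein sequences and the checkpoint name from this PR:

```python
from transformers import EsmTokenizer

tokenizer = EsmTokenizer.from_pretrained("Rocketknight1/esm2_t6_8M_UR50D")

print(len(tokenizer.get_vocab()))  # amino-acid vocabulary plus special tokens

# Sequence pairs are joined with <eos> standing in for the separator token.
encoded = tokenizer("MKTVRQERLKSIV", "QDIAYLRSLGY")
print(tokenizer.convert_ids_to_tokens(encoded["input_ids"]))
```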
32 changes: 31 additions & 1 deletion src/transformers/models/esm/__init__.py
@@ -17,7 +17,7 @@
# limitations under the License.
from typing import TYPE_CHECKING

from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_tf_available, is_torch_available


_import_structure = {
@@ -40,6 +40,21 @@
"EsmPreTrainedModel",
]

try:
if not is_tf_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_tf_esm"] = [
"TF_ESM_PRETRAINED_MODEL_ARCHIVE_LIST",
"TFEsmForMaskedLM",
"TFEsmForSequenceClassification",
"TFEsmForTokenClassification",
"TFEsmModel",
"TFEsmPreTrainedModel",
]


if TYPE_CHECKING:
from .configuration_esm import ESM_PRETRAINED_CONFIG_ARCHIVE_MAP, EsmConfig
@@ -60,6 +75,21 @@
EsmPreTrainedModel,
)

try:
if not is_tf_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_tf_esm import (
TF_ESM_PRETRAINED_MODEL_ARCHIVE_LIST,
TFEsmForMaskedLM,
TFEsmForSequenceClassification,
TFEsmForTokenClassification,
TFEsmModel,
TFEsmPreTrainedModel,
)


else:
import sys
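
The lazy-import block above mirrors the existing PyTorch guard: the TF symbols are only registered when TensorFlow is importable. Downstream code that wants to stay backend-agnostic can rely on the same check, for example:

```python
from transformers.utils import is_tf_available

if is_tf_available():
    from transformers import TFEsmModel
else:
    from transformers import EsmModel  # fall back to the PyTorch implementation
```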
88 changes: 36 additions & 52 deletions src/transformers/models/esm/modeling_esm.py
@@ -1,5 +1,5 @@
# coding=utf-8
# Copyright 2022 Facebook and The HuggingFace Inc. team. All rights reserved.
# Copyright 2022 Meta and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -42,12 +42,14 @@

logger = logging.get_logger(__name__)

_CHECKPOINT_FOR_DOC = "facebook/esm-1b"
_CHECKPOINT_FOR_DOC = "Rocketknight1/esm2_t6_8M_UR50D"
_CONFIG_FOR_DOC = "EsmConfig"
_TOKENIZER_FOR_DOC = "EsmTokenizer"

ESM_PRETRAINED_MODEL_ARCHIVE_LIST = [
"facebook/esm-1b",
"Rocketknight1/esm2_t6_8M_UR50D",
"Rocketknight1/esm2_t12_35M_UR50D",
# This is not a complete list of all ESM models!
# See all ESM models at https://huggingface.co/models?filter=esm
]

@@ -115,7 +117,6 @@ class EsmEmbeddings(nn.Module):
def __init__(self, config):
super().__init__()
self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)

if config.emb_layer_norm_before:
self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
@@ -658,15 +659,6 @@ def _init_weights(self, module):
module.bias.data.zero_()
module.weight.data.fill_(1.0)

def update_keys_to_ignore(self, config, del_keys_to_ignore):
"""Remove some keys from ignore list"""
if not config.tie_word_embeddings:
# must make a new list, or the class variable gets modified!
self._keys_to_ignore_on_save = [k for k in self._keys_to_ignore_on_save if k not in del_keys_to_ignore]
self._keys_to_ignore_on_load_missing = [
k for k in self._keys_to_ignore_on_load_missing if k not in del_keys_to_ignore
]


ESM_START_DOCSTRING = r"""
@@ -907,8 +899,7 @@ def forward(

@add_start_docstrings("""ESM Model with a `language modeling` head on top.""", ESM_START_DOCSTRING)
class EsmForMaskedLM(EsmPreTrainedModel):
_keys_to_ignore_on_save = [r"lm_head.decoder.weight", r"lm_head.decoder.bias"]
_keys_to_ignore_on_load_missing = [r"position_ids", r"lm_head.decoder.weight", r"lm_head.decoder.bias"]
_keys_to_ignore_on_load_missing = [r"position_ids"]
_keys_to_ignore_on_load_unexpected = [r"pooler"]

def __init__(self, config):
@@ -923,9 +914,6 @@ def __init__(self, config):
self.esm = EsmModel(config, add_pooling_layer=False)
self.lm_head = EsmLMHead(config)

# The LM head weights require special treatment only when they are tied with the word embeddings
self.update_keys_to_ignore(config, ["lm_head.decoder.weight"])

self.init_weights()

def get_output_embeddings(self):
@@ -944,17 +932,17 @@ def set_output_embeddings(self, new_embeddings):
)
def forward(
self,
input_ids=None,
attention_mask=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
labels=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
labels: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
):
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -1009,17 +997,13 @@ def __init__(self, config):
self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
self.bias = nn.Parameter(torch.zeros(config.vocab_size))

# Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
self.decoder.bias = self.bias

def forward(self, features, **kwargs):
x = self.dense(features)
x = gelu(x)
x = self.layer_norm(x)

# project back to size of vocabulary with bias
x = self.decoder(x)

x = self.decoder(x) + self.bias
return x


@@ -1052,15 +1036,15 @@ def __init__(self, config):
)
def forward(
self,
input_ids=None,
attention_mask=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
labels=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
):
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@@ -1148,15 +1132,15 @@ def __init__(self, config):
)
def forward(
self,
input_ids=None,
attention_mask=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
labels=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
):
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
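
The `EsmLMHead` change in this file unties the output bias from `decoder.bias` and adds it explicitly after the projection, keeping the PyTorch and TF weight layouts aligned (see the "Handle TF bias correctly" and "Stop ignoring the decoder weights in PT" items in the commit message). A standalone sketch of the pattern, with a hypothetical class name and dimensions:

```python
import torch
from torch import nn


class LMHeadWithExplicitBias(nn.Module):
    """Toy head mirroring the EsmLMHead pattern: the bias is a separate parameter."""

    def __init__(self, hidden_size: int = 32, vocab_size: int = 100):
        super().__init__()
        self.decoder = nn.Linear(hidden_size, vocab_size, bias=False)
        self.bias = nn.Parameter(torch.zeros(vocab_size))

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        # Adding the bias explicitly avoids aliasing it to decoder.bias,
        # so the parameter is saved and converted under a single name.
        return self.decoder(features) + self.bias


head = LMHeadWithExplicitBias()
print(head(torch.randn(2, 5, 32)).shape)  # torch.Size([2, 5, 100])
```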
