TF port of ESM #19587

Merged
merged 23 commits on Oct 17, 2022
Commits (23; changes shown from 19 commits)
f77c6fd
Partial TF port for ESM model
Oct 4, 2022
8a2449b
Add ESM-TF tests
Rocketknight1 Oct 10, 2022
ddcb48d
Add the various imports for TF-ESM
Rocketknight1 Oct 10, 2022
10cac9f
TF weight conversion almost ready
Rocketknight1 Oct 10, 2022
4f6eff2
Stop ignoring the decoder weights in PT
Rocketknight1 Oct 12, 2022
9ce1acd
Add tests and lots of fixes
Rocketknight1 Oct 13, 2022
8a89db8
fix-copies
Rocketknight1 Oct 13, 2022
c68f3ef
Fix imports, add model docs
Rocketknight1 Oct 13, 2022
f074cd7
Add get_vocab() to tokenizer
Rocketknight1 Oct 13, 2022
840df9e
Fix vocab links for pretrained files
Rocketknight1 Oct 13, 2022
d89bef0
Allow multiple inputs with a sep
Rocketknight1 Oct 13, 2022
345b360
Use EOS as SEP token because ESM vocab lacks SEP
Rocketknight1 Oct 14, 2022
c3cc44f
Correctly return special tokens mask from ESM tokenizer
Rocketknight1 Oct 14, 2022
5dbf24c
make fixup
Rocketknight1 Oct 14, 2022
e659a9f
Stop testing unsupported embedding resizing
Rocketknight1 Oct 14, 2022
7ebbf2d
Handle TF bias correctly
Rocketknight1 Oct 14, 2022
9f10249
Skip all models with slow tokenizers in the token classification test
Rocketknight1 Oct 14, 2022
357877a
Fixing the batch/unbatcher of pipelines to accommodate the `None` being…
Narsil Oct 14, 2022
f5fbfb9
Fixing pipeline bug caused by slow tokenizer being different.
Narsil Oct 14, 2022
7f85040
Update src/transformers/models/esm/modeling_tf_esm.py
Rocketknight1 Oct 17, 2022
82d39c1
Update src/transformers/models/esm/modeling_tf_esm.py
Rocketknight1 Oct 17, 2022
068374f
Update src/transformers/models/esm/modeling_tf_esm.py
Rocketknight1 Oct 17, 2022
ca2a52f
Update set_input_embeddings and the copyright notices
Rocketknight1 Oct 17, 2022
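Several of the commits above rework the slow `EsmTokenizer` (adding `get_vocab()`, returning a correct special-tokens mask, allowing sequence pairs, and using EOS in place of the missing SEP token). A minimal sketch of that behavior from the user side — the checkpoint name is the temporary `Rocketknight1` repo referenced later in this diff, and the separator behavior is an expectation based on the commit messages, not verified output:

```python
from transformers import EsmTokenizer

# Temporary checkpoint name used throughout this PR (slated to move under facebook/).
tok = EsmTokenizer.from_pretrained("Rocketknight1/esm2_t6_8M_UR50D")

# get_vocab() was added in this PR: a plain token-string -> id mapping.
print(len(tok.get_vocab()))

# Two protein sequences; per the commits above, the EOS token is expected to act
# as the separator because the ESM vocab has no dedicated SEP token.
enc = tok("MKTAYIAK", "GLSDGEWQL", return_special_tokens_mask=True)
print(tok.convert_ids_to_tokens(enc["input_ids"]))
print(enc["special_tokens_mask"])  # special-tokens mask handling was also fixed here
```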
2 changes: 1 addition & 1 deletion docs/source/en/index.mdx
@@ -243,7 +243,7 @@ Flax), PyTorch, and/or TensorFlow.
| ELECTRA | ✅ | ✅ | ✅ | ✅ | ✅ |
| Encoder decoder | ❌ | ❌ | ✅ | ✅ | ✅ |
| ERNIE | ❌ | ❌ | ✅ | ❌ | ❌ |
| ESM | ✅ | ❌ | ✅ | ❌ | ❌ |
| ESM | ✅ | ❌ | ✅ | ✅ | ❌ |
| FairSeq Machine-Translation | ✅ | ❌ | ✅ | ❌ | ❌ |
| FlauBERT | ✅ | ❌ | ✅ | ✅ | ❌ |
| FLAVA | ❌ | ❌ | ✅ | ❌ | ❌ |
20 changes: 20 additions & 0 deletions docs/source/en/model_doc/esm.mdx
@@ -107,3 +107,23 @@ and [Matt](https://huggingface.co/Rocketknight1).

[[autodoc]] EsmForTokenClassification
- forward

## TFEsmModel

[[autodoc]] TFEsmModel
- call

## TFEsmForMaskedLM

[[autodoc]] TFEsmForMaskedLM
- call

## TFEsmForSequenceClassification

[[autodoc]] TFEsmForSequenceClassification
- call

## TFEsmForTokenClassification

[[autodoc]] TFEsmForTokenClassification
- call
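The doc additions above are autodoc stubs only; for orientation, here is a minimal usage sketch for one of the new TF classes, assuming TensorFlow is installed and using the temporary checkpoint name that appears in the modeling files further down (this snippet is not part of the PR itself):

```python
import tensorflow as tf

from transformers import EsmTokenizer, TFEsmForMaskedLM

checkpoint = "Rocketknight1/esm2_t6_8M_UR50D"  # temporary repo, later moved under facebook/
tokenizer = EsmTokenizer.from_pretrained(checkpoint)
model = TFEsmForMaskedLM.from_pretrained(checkpoint)

# Mask a single residue in a short protein sequence and predict it.
sequence = "MKTAYIAKQR" + tokenizer.mask_token + "ISFVKSHFSRQLEERLGLIEVQ"
inputs = tokenizer(sequence, return_tensors="tf")
logits = model(**inputs).logits

masked_index = int(tf.where(inputs["input_ids"][0] == tokenizer.mask_token_id)[0][0])
predicted_id = int(tf.argmax(logits[0, masked_index]))
print(tokenizer.convert_ids_to_tokens([predicted_id]))
```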
18 changes: 18 additions & 0 deletions src/transformers/__init__.py
@@ -2462,6 +2462,16 @@
]
)
_import_structure["models.encoder_decoder"].append("TFEncoderDecoderModel")
_import_structure["models.esm"].extend(
[
"ESM_PRETRAINED_MODEL_ARCHIVE_LIST",
"TFEsmForMaskedLM",
"TFEsmForSequenceClassification",
"TFEsmForTokenClassification",
"TFEsmModel",
"TFEsmPreTrainedModel",
]
)
_import_structure["models.flaubert"].extend(
[
"TF_FLAUBERT_PRETRAINED_MODEL_ARCHIVE_LIST",
@@ -5134,6 +5144,14 @@
TFElectraPreTrainedModel,
)
from .models.encoder_decoder import TFEncoderDecoderModel
from .models.esm import (
ESM_PRETRAINED_MODEL_ARCHIVE_LIST,
TFEsmForMaskedLM,
TFEsmForSequenceClassification,
TFEsmForTokenClassification,
TFEsmModel,
TFEsmPreTrainedModel,
)
from .models.flaubert import (
TF_FLAUBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
TFFlaubertForMultipleChoice,
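With the import blocks above in place, the TF-ESM classes are exposed from the top-level `transformers` namespace whenever TensorFlow is importable. A quick sketch of what that enables — the checkpoint name is again the PR's temporary repo, and `from_pt=True` is an assumption for the case where the repo only hosts PyTorch weights, in which case it exercises the new TF weight conversion:

```python
from transformers import TFEsmModel
from transformers.utils import is_tf_available

# The new symbols are only usable when TensorFlow is installed; without it the
# lazy module exposes placeholder objects that raise a backend error when used.
assert is_tf_available()

# Cross-load PyTorch weights into the TF architecture if no TF weights are hosted yet.
model = TFEsmModel.from_pretrained("Rocketknight1/esm2_t6_8M_UR50D", from_pt=True)
print(model.config.hidden_size)
```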
5 changes: 5 additions & 0 deletions src/transformers/models/auto/modeling_tf_auto.py
@@ -47,6 +47,7 @@
("distilbert", "TFDistilBertModel"),
("dpr", "TFDPRQuestionEncoder"),
("electra", "TFElectraModel"),
("esm", "TFEsmModel"),
("flaubert", "TFFlaubertModel"),
("funnel", ("TFFunnelModel", "TFFunnelBaseModel")),
("gpt2", "TFGPT2Model"),
@@ -129,6 +130,7 @@
("ctrl", "TFCTRLLMHeadModel"),
("distilbert", "TFDistilBertForMaskedLM"),
("electra", "TFElectraForMaskedLM"),
("esm", "TFEsmForMaskedLM"),
("flaubert", "TFFlaubertWithLMHeadModel"),
("funnel", "TFFunnelForMaskedLM"),
("gpt2", "TFGPT2LMHeadModel"),
@@ -223,6 +225,7 @@
("deberta-v2", "TFDebertaV2ForMaskedLM"),
("distilbert", "TFDistilBertForMaskedLM"),
("electra", "TFElectraForMaskedLM"),
("esm", "TFEsmForMaskedLM"),
("flaubert", "TFFlaubertWithLMHeadModel"),
("funnel", "TFFunnelForMaskedLM"),
("layoutlm", "TFLayoutLMForMaskedLM"),
@@ -273,6 +276,7 @@
("deberta-v2", "TFDebertaV2ForSequenceClassification"),
("distilbert", "TFDistilBertForSequenceClassification"),
("electra", "TFElectraForSequenceClassification"),
("esm", "TFEsmForSequenceClassification"),
("flaubert", "TFFlaubertForSequenceClassification"),
("funnel", "TFFunnelForSequenceClassification"),
("gpt2", "TFGPT2ForSequenceClassification"),
@@ -346,6 +350,7 @@
("deberta-v2", "TFDebertaV2ForTokenClassification"),
("distilbert", "TFDistilBertForTokenClassification"),
("electra", "TFElectraForTokenClassification"),
("esm", "TFEsmForTokenClassification"),
("flaubert", "TFFlaubertForTokenClassification"),
("funnel", "TFFunnelForTokenClassification"),
("layoutlm", "TFLayoutLMForTokenClassification"),
1 change: 1 addition & 0 deletions src/transformers/models/auto/tokenization_auto.py
@@ -122,6 +122,7 @@
),
("electra", ("ElectraTokenizer", "ElectraTokenizerFast" if is_tokenizers_available() else None)),
("ernie", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
("esm", ("EsmTokenizer", None)),
("flaubert", ("FlaubertTokenizer", None)),
("fnet", ("FNetTokenizer", "FNetTokenizerFast" if is_tokenizers_available() else None)),
("fsmt", ("FSMTTokenizer", None)),
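With the auto-mapping entries above, the generic Auto classes resolve ESM checkpoints to the new TF model classes, and `AutoTokenizer` resolves to the slow `EsmTokenizer` (no fast tokenizer exists for ESM). A minimal sketch, once more using the PR's temporary checkpoint name:

```python
from transformers import AutoTokenizer, TFAutoModelForMaskedLM

checkpoint = "Rocketknight1/esm2_t6_8M_UR50D"  # temporary repo, later moved under facebook/
tokenizer = AutoTokenizer.from_pretrained(checkpoint)       # -> EsmTokenizer (slow only)
model = TFAutoModelForMaskedLM.from_pretrained(checkpoint)  # -> TFEsmForMaskedLM

print(type(tokenizer).__name__, type(model).__name__)
```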
32 changes: 31 additions & 1 deletion src/transformers/models/esm/__init__.py
@@ -17,7 +17,7 @@
# limitations under the License.
from typing import TYPE_CHECKING

from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_tf_available, is_torch_available


_import_structure = {
@@ -40,6 +40,21 @@
"EsmPreTrainedModel",
]

try:
if not is_tf_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_tf_esm"] = [
"TF_ESM_PRETRAINED_MODEL_ARCHIVE_LIST",
"TFEsmForMaskedLM",
"TFEsmForSequenceClassification",
"TFEsmForTokenClassification",
"TFEsmModel",
"TFEsmPreTrainedModel",
]


if TYPE_CHECKING:
from .configuration_esm import ESM_PRETRAINED_CONFIG_ARCHIVE_MAP, EsmConfig
@@ -60,6 +75,21 @@
EsmPreTrainedModel,
)

try:
if not is_tf_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_tf_esm import (
TF_ESM_PRETRAINED_MODEL_ARCHIVE_LIST,
TFEsmForMaskedLM,
TFEsmForSequenceClassification,
TFEsmForTokenClassification,
TFEsmModel,
TFEsmPreTrainedModel,
)


else:
import sys
86 changes: 35 additions & 51 deletions src/transformers/models/esm/modeling_esm.py
@@ -42,12 +42,14 @@

logger = logging.get_logger(__name__)

_CHECKPOINT_FOR_DOC = "facebook/esm-1b"
_CHECKPOINT_FOR_DOC = "Rocketknight1/esm2_t6_8M_UR50D"
Collaborator: Will need an update :-)

Member Author: Yep, all of these will be moved to facebook before the next release!
_CONFIG_FOR_DOC = "EsmConfig"
_TOKENIZER_FOR_DOC = "EsmTokenizer"

ESM_PRETRAINED_MODEL_ARCHIVE_LIST = [
"facebook/esm-1b",
"Rocketknight1/esm2_t6_8M_UR50D",
"Rocketknight1/esm2_t12_35M_UR50D",
# This is not a complete list of all ESM models!
# See all ESM models at https://huggingface.co/models?filter=esm
]

@@ -115,7 +117,6 @@ class EsmEmbeddings(nn.Module):
def __init__(self, config):
super().__init__()
self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)

if config.emb_layer_norm_before:
self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
@@ -658,15 +659,6 @@ def _init_weights(self, module):
module.bias.data.zero_()
module.weight.data.fill_(1.0)

def update_keys_to_ignore(self, config, del_keys_to_ignore):
"""Remove some keys from ignore list"""
if not config.tie_word_embeddings:
# must make a new list, or the class variable gets modified!
self._keys_to_ignore_on_save = [k for k in self._keys_to_ignore_on_save if k not in del_keys_to_ignore]
self._keys_to_ignore_on_load_missing = [
k for k in self._keys_to_ignore_on_load_missing if k not in del_keys_to_ignore
]


ESM_START_DOCSTRING = r"""

@@ -907,8 +899,7 @@ def forward(

@add_start_docstrings("""ESM Model with a `language modeling` head on top.""", ESM_START_DOCSTRING)
class EsmForMaskedLM(EsmPreTrainedModel):
_keys_to_ignore_on_save = [r"lm_head.decoder.weight", r"lm_head.decoder.bias"]
_keys_to_ignore_on_load_missing = [r"position_ids", r"lm_head.decoder.weight", r"lm_head.decoder.bias"]
_keys_to_ignore_on_load_missing = [r"position_ids"]
_keys_to_ignore_on_load_unexpected = [r"pooler"]

def __init__(self, config):
@@ -923,9 +914,6 @@ def __init__(self, config):
self.esm = EsmModel(config, add_pooling_layer=False)
self.lm_head = EsmLMHead(config)

# The LM head weights require special treatment only when they are tied with the word embeddings
self.update_keys_to_ignore(config, ["lm_head.decoder.weight"])

self.init_weights()

def get_output_embeddings(self):
@@ -944,17 +932,17 @@ def set_output_embeddings(self, new_embeddings):
)
def forward(
self,
input_ids=None,
attention_mask=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
labels=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
labels: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
):
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -1009,17 +997,13 @@ def __init__(self, config):
self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
self.bias = nn.Parameter(torch.zeros(config.vocab_size))

# Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
self.decoder.bias = self.bias

def forward(self, features, **kwargs):
x = self.dense(features)
x = gelu(x)
x = self.layer_norm(x)

# project back to size of vocabulary with bias
x = self.decoder(x)

x = self.decoder(x) + self.bias
return x
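The reworked head above keeps the bias as a standalone `nn.Parameter` and adds it explicitly after the projection instead of attaching it to the `nn.Linear`, which lines up with how the TF port tracks its bias (see the "Handle TF bias correctly" commit). A stripped-down, hypothetical module showing the same pattern in isolation (not the PR's class):

```python
import torch
from torch import nn


class ExplicitBiasLMHead(nn.Module):
    """Toy head: bias-free Linear projection to vocab size, plus a separately
    tracked bias parameter added in forward (same pattern as the diff above)."""

    def __init__(self, hidden_size: int, vocab_size: int):
        super().__init__()
        self.decoder = nn.Linear(hidden_size, vocab_size, bias=False)
        self.bias = nn.Parameter(torch.zeros(vocab_size))

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        return self.decoder(features) + self.bias


head = ExplicitBiasLMHead(hidden_size=8, vocab_size=33)
print(head(torch.randn(2, 5, 8)).shape)  # torch.Size([2, 5, 33])
```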


@@ -1052,15 +1036,15 @@ def __init__(self, config):
)
def forward(
self,
input_ids=None,
attention_mask=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
labels=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
):
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@@ -1148,15 +1132,15 @@ def __init__(self, config):
)
def forward(
self,
input_ids=None,
attention_mask=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
labels=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
):
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Expand Down