Add ESMFold (huggingface#19977)
* initial commit

* First draft that gets outputs without crashing!

* Add all the ported openfold dependencies

* testing

* Restructure config files for ESMFold

* Debugging to find output discrepancies

* Mainly style

* Make model runnable without extra deps

* Remove utils and merge them to the modeling file

* Use correct gelu and remove some debug prints

* More cleanup

* Update esm docs

* Update conversion script to support ESMFold properly

* Port some top-level changes from ESMFold repo

* Expand EsmFold docstrings

* Make attention_mask optional (default to all 1s)

* Add inference test for ESMFold

* Use config and not n kwargs

* Add modeling output class

* Remove einops

* Remove chunking in ESM FFN

* Update tests for ESMFold

* Quality

* Repo consistency

* Remove tree dependency from ESMFold

* make fixup

* Add an error in case my structure map function breaks later

* Remove needless code

* Stop auto-casting the LM to float16 so CPU tests pass

* Stop auto-casting the LM to float16 so CPU tests pass

* Final test updates

* Split test file

* Copyright and quality

* Unpin PyTorch to see built doc

* Fix config file to_dict() method

* Add some docstrings to the output

* Skip TF checkpoint tests for ESM until we reupload those

* make fixup

* More docstrings

* Unpin to get even with main

* Flag example to write

Co-authored-by: Sylvain Gugger <Sylvain.gugger@gmail.com>
2 people authored and amyeroberts committed Nov 1, 2022
1 parent 297d700 commit a573853
Showing 22 changed files with 6,821 additions and 90 deletions.
27 changes: 22 additions & 5 deletions docs/source/en/model_doc/esm.mdx
@@ -14,8 +14,8 @@ specific language governing permissions and limitations under the License.

## Overview
This page provides code and pre-trained weights for Transformer protein language models from Meta AI's Fundamental
AI Research Team, providing the state-of-the-art ESM-2, and the previously released ESM-1b and ESM-1v. Transformer
protein language models were introduced in the paper [Biological structure and function emerge from scaling
AI Research Team, providing the state-of-the-art ESMFold and ESM-2, and the previously released ESM-1b and ESM-1v.
Transformer protein language models were introduced in the paper [Biological structure and function emerge from scaling
unsupervised learning to 250 million protein sequences](https://www.pnas.org/content/118/15/e2016239118) by
Alexander Rives, Joshua Meier, Tom Sercu, Siddharth Goyal, Zeming Lin, Jason Liu, Demi Guo, Myle Ott,
C. Lawrence Zitnick, Jerry Ma, and Rob Fergus.
@@ -27,6 +27,13 @@ It was released with the paper [Language models of protein sequences at the scale of evolution enable accurate
structure prediction](https://doi.org/10.1101/2022.07.20.500902) by Zeming Lin, Halil Akin, Roshan Rao, Brian Hie,
Zhongkai Zhu, Wenting Lu, Allan dos Santos Costa, Maryam Fazel-Zarandi, Tom Sercu, Sal Candido and Alexander Rives.

Also introduced in this paper was ESMFold. It uses an ESM-2 stem with a head that can predict folded protein
structures with state-of-the-art accuracy. Unlike [AlphaFold2](https://www.nature.com/articles/s41586-021-03819-2),
it relies on the token embeddings from the large pre-trained protein language model stem and does not perform a multiple
sequence alignment (MSA) step at inference time, which means that ESMFold checkpoints are fully "standalone" -
they do not require a database of known protein sequences and structures with associated external query tools
to make predictions, and are much faster as a result.


The abstract from
"Biological structure and function emerge from scaling unsupervised learning to 250
@@ -63,17 +70,22 @@ order of magnitude faster than AlphaFold2, enabling exploration of the structural space of metagenomic
proteins in practical timescales.*




Tips:

- ESM models are trained with a masked language modeling (MLM) objective.
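
Because the checkpoints are trained with this MLM objective, a quick way to exercise one is to mask a residue and let the model fill it back in. A minimal sketch; the `facebook/esm2_t6_8M_UR50D` checkpoint name and the per-residue tokenization behaviour are assumptions rather than guarantees made by this commit:

```python
import torch
from transformers import AutoTokenizer, EsmForMaskedLM

# Assumed ESM-2 checkpoint name, used only for illustration.
model_name = "facebook/esm2_t6_8M_UR50D"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = EsmForMaskedLM.from_pretrained(model_name)
model.eval()

sequence = "MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ"
inputs = tokenizer(sequence, return_tensors="pt")
masked_position = 5
inputs["input_ids"][0, masked_position] = tokenizer.mask_token_id  # mask one residue by id

with torch.no_grad():
    logits = model(**inputs).logits

predicted_id = logits[0, masked_position].argmax(dim=-1)
print(tokenizer.convert_ids_to_tokens([predicted_id.item()]))  # model's guess for the masked residue
```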

The original code can be found [here](https://github.com/facebookresearch/esm) and was
developed by the Fundamental AI Research team at Meta AI.
This model was contributed to huggingface by [jasonliu](https://huggingface.co/jasonliu)
ESM-1b, ESM-1v and ESM-2 were contributed to huggingface by [jasonliu](https://huggingface.co/jasonliu)
and [Matt](https://huggingface.co/Rocketknight1).

ESMFold was contributed to huggingface by [Matt](https://huggingface.co/Rocketknight1) and
[Sylvain](https://huggingface.co/sgugger), with a big thank you to Nikita Smetanin, Roshan Rao and Tom Sercu for their
help throughout the process!

The HuggingFace port of ESMFold uses portions of the [openfold](https://github.com/aqlaboratory/openfold) library.
The `openfold` library is licensed under the Apache License 2.0.

## EsmConfig

[[autodoc]] EsmConfig
@@ -108,6 +120,11 @@ and [Matt](https://huggingface.co/Rocketknight1).
[[autodoc]] EsmForTokenClassification
- forward

## EsmForProteinFolding

[[autodoc]] EsmForProteinFolding
- forward

## TFEsmModel

[[autodoc]] TFEsmModel
4 changes: 4 additions & 0 deletions src/transformers/__init__.py
@@ -1265,7 +1265,9 @@
_import_structure["models.esm"].extend(
[
"ESM_PRETRAINED_MODEL_ARCHIVE_LIST",
"EsmFoldPreTrainedModel",
"EsmForMaskedLM",
"EsmForProteinFolding",
"EsmForSequenceClassification",
"EsmForTokenClassification",
"EsmModel",
@@ -4144,7 +4146,9 @@
)
from .models.esm import (
ESM_PRETRAINED_MODEL_ARCHIVE_LIST,
EsmFoldPreTrainedModel,
EsmForMaskedLM,
EsmForProteinFolding,
EsmForSequenceClassification,
EsmForTokenClassification,
EsmModel,
3 changes: 2 additions & 1 deletion src/transformers/models/esm/__init__.py
@@ -39,6 +39,7 @@
"EsmModel",
"EsmPreTrainedModel",
]
_import_structure["modeling_esmfold"] = ["EsmForProteinFolding", "EsmFoldPreTrainedModel"]

try:
if not is_tf_available():
@@ -55,7 +56,6 @@
"TFEsmPreTrainedModel",
]


if TYPE_CHECKING:
from .configuration_esm import ESM_PRETRAINED_CONFIG_ARCHIVE_MAP, EsmConfig
from .tokenization_esm import EsmTokenizer
@@ -74,6 +74,7 @@
EsmModel,
EsmPreTrainedModel,
)
from .modeling_esmfold import EsmFoldPreTrainedModel, EsmForProteinFolding

try:
if not is_tf_available():
235 changes: 231 additions & 4 deletions src/transformers/models/esm/configuration_esm.py
@@ -1,5 +1,5 @@
# coding=utf-8
# Copyright 2021 Facebook and The HuggingFace Inc. team. All rights reserved.
# Copyright 2022 Meta and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,12 +14,16 @@
# limitations under the License.
""" ESM model configuration"""

from dataclasses import asdict, dataclass
from typing import Optional

from ...configuration_utils import PretrainedConfig
from ...utils import logging


logger = logging.get_logger(__name__)

# TODO Update this
ESM_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"facebook/esm-1b": "https://huggingface.co/facebook/esm-1b/resolve/main/config.json",
# See all ESM models at https://huggingface.co/models?filter=esm
@@ -118,9 +122,12 @@ def __init__(
classifier_dropout=None,
emb_layer_norm_before=None,
token_dropout=False,
is_folding_model=False,
esmfold_config=None,
vocab_list=None,
**kwargs
):
super().__init__(pad_token_id=pad_token_id, **kwargs)
super().__init__(pad_token_id=pad_token_id, mask_token_id=mask_token_id, **kwargs)

self.vocab_size = vocab_size
self.hidden_size = hidden_size
@@ -138,5 +145,225 @@ def __init__(
self.classifier_dropout = classifier_dropout
self.emb_layer_norm_before = emb_layer_norm_before
self.token_dropout = token_dropout
self.mask_token_id = mask_token_id
self.pad_token_id = pad_token_id
self.is_folding_model = is_folding_model
if is_folding_model:
if esmfold_config is None:
logger.info("No esmfold_config supplied for folding model, using default values.")
esmfold_config = EsmFoldConfig()
elif isinstance(esmfold_config, dict):
esmfold_config = EsmFoldConfig(**esmfold_config)
self.esmfold_config = esmfold_config
if vocab_list is None:
logger.warning("No vocab_list supplied for folding model, assuming the ESM-2 vocabulary!")
self.vocab_list = get_default_vocab_list()
else:
self.vocab_list = vocab_list
else:
self.esmfold_config = None
self.vocab_list = None
if self.esmfold_config is not None and getattr(self.esmfold_config, "use_esm_attn_map", False):
raise ValueError("The HuggingFace port of ESMFold does not support use_esm_attn_map at this time!")

def to_dict(self):
"""
Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].
Returns:
`Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance.
"""
output = super().to_dict()
if isinstance(self.esmfold_config, EsmFoldConfig):
output["esmfold_config"] = self.esmfold_config.to_dict()
return output
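
As an aside on the folding-specific branch above: a plain `esmfold_config` dict is promoted to the dataclasses defined below, and `to_dict()` flattens them back out. A minimal sketch (not part of this diff), assuming `EsmConfig` is imported from the library as exposed by this commit:

```python
from transformers import EsmConfig

# A plain dict is upgraded to EsmFoldConfig / TrunkConfig instances in __init__;
# vocab_list falls back to the default ESM-2 vocabulary defined at the bottom of this file.
config = EsmConfig(is_folding_model=True, esmfold_config={"trunk": {"num_blocks": 4}})
print(type(config.esmfold_config).__name__)    # EsmFoldConfig
print(config.esmfold_config.trunk.num_blocks)  # 4
print(len(config.vocab_list))                  # 33 tokens from get_default_vocab_list()

# to_dict() converts the nested dataclasses back into plain dicts for JSON serialization.
as_dict = config.to_dict()
print(isinstance(as_dict["esmfold_config"], dict))  # True
```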


@dataclass
class EsmFoldConfig:
esm_type: str = None
fp16_esm: bool = True
use_esm_attn_map: bool = False
esm_ablate_pairwise: bool = False
esm_ablate_sequence: bool = False
esm_input_dropout: float = 0

embed_aa: bool = True
bypass_lm: bool = False

lddt_head_hid_dim: int = 128
trunk: "TrunkConfig" = None

def __post_init__(self):
if self.trunk is None:
self.trunk = TrunkConfig()
elif isinstance(self.trunk, dict):
self.trunk = TrunkConfig(**self.trunk)

def to_dict(self):
"""
Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].
Returns:
`Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance.
"""
output = asdict(self)
output["trunk"] = self.trunk.to_dict()
return output


@dataclass
class TrunkConfig:
num_blocks: int = 48
sequence_state_dim: int = 1024
pairwise_state_dim: int = 128
sequence_head_width: int = 32
pairwise_head_width: int = 32
position_bins: int = 32
dropout: float = 0
layer_drop: float = 0
cpu_grad_checkpoint: bool = False
max_recycles: int = 4
chunk_size: Optional[int] = 128
structure_module: "StructureModuleConfig" = None

def __post_init__(self):
if self.structure_module is None:
self.structure_module = StructureModuleConfig()
elif isinstance(self.structure_module, dict):
self.structure_module = StructureModuleConfig(**self.structure_module)

if self.max_recycles <= 0:
raise ValueError(f"`max_recycles` should be positive, got {self.max_recycles}.")
if self.sequence_state_dim % self.sequence_head_width != 0:
raise ValueError(
"`sequence_state_dim` should be a round multiple of `sequence_head_width`, got"
f" {self.sequence_state_dim} and {self.sequence_head_width}."
)
if self.pairwise_state_dim % self.pairwise_head_width != 0:
raise ValueError(
"`pairwise_state_dim` should be a round multiple of `pairwise_head_width`, got"
f" {self.pairwise_state_dim} and {self.pairwise_head_width}."
)

sequence_num_heads = self.sequence_state_dim // self.sequence_head_width
pairwise_num_heads = self.pairwise_state_dim // self.pairwise_head_width

if self.sequence_state_dim != sequence_num_heads * self.sequence_head_width:
raise ValueError(
"`sequence_state_dim` should be equal to `sequence_num_heads * sequence_head_width, got"
f" {self.sequence_state_dim} != {sequence_num_heads} * {self.sequence_head_width}."
)
if self.pairwise_state_dim != pairwise_num_heads * self.pairwise_head_width:
raise ValueError(
"`pairwise_state_dim` should be equal to `pairwise_num_heads * pairwise_head_width, got"
f" {self.pairwise_state_dim} != {pairwise_num_heads} * {self.pairwise_head_width}."
)
if self.pairwise_state_dim % 2 != 0:
raise ValueError(f"`pairwise_state_dim` should be even, got {self.pairwise_state_dim}.")

if self.dropout >= 0.4:
raise ValueError(f"`dropout` should not be greater than 0.4, got {self.dropout}.")

def to_dict(self):
"""
Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].
Returns:
`Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance.
"""
output = asdict(self)
output["structure_module"] = self.structure_module.to_dict()
return output
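
The checks above reduce to a simple relationship between state dimensions and head widths. A worked check (not part of this diff) using the default values from the dataclass:

```python
# Worked check of the default TrunkConfig head arithmetic, using the defaults above.
sequence_state_dim, sequence_head_width = 1024, 32
pairwise_state_dim, pairwise_head_width = 128, 32

sequence_num_heads = sequence_state_dim // sequence_head_width  # 32 sequence-attention heads
pairwise_num_heads = pairwise_state_dim // pairwise_head_width  # 4 pairwise-attention heads

assert sequence_state_dim == sequence_num_heads * sequence_head_width
assert pairwise_state_dim == pairwise_num_heads * pairwise_head_width
assert pairwise_state_dim % 2 == 0  # the config also requires an even pairwise dimension
```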


@dataclass
class StructureModuleConfig:
"""
Args:
sequence_dim:
Single representation channel dimension
pairwise_dim:
Pair representation channel dimension
ipa_dim:
IPA hidden channel dimension
resnet_dim:
Angle resnet (Alg. 23 lines 11-14) hidden channel dimension
num_heads_ipa:
Number of IPA heads
num_qk_points:
Number of query/key points to generate during IPA
num_v_points:
Number of value points to generate during IPA
dropout_rate:
Dropout rate used throughout the layer
num_blocks:
Number of structure module blocks
num_transition_layers:
Number of layers in the single representation transition (Alg. 23 lines 8-9)
num_resnet_blocks:
Number of blocks in the angle resnet
num_angles:
Number of angles to generate in the angle resnet
trans_scale_factor:
Scale of single representation transition hidden dimension
epsilon:
Small number used in angle resnet normalization
inf:
Large number used for attention masking
"""

sequence_dim: int = 384
pairwise_dim: int = 128
ipa_dim: int = 16
resnet_dim: int = 128
num_heads_ipa: int = 12
num_qk_points: int = 4
num_v_points: int = 8
dropout_rate: float = 0.1
num_blocks: int = 8
num_transition_layers: int = 1
num_resnet_blocks: int = 2
num_angles: int = 7
trans_scale_factor: int = 10
epsilon: float = 1e-8
inf: float = 1e5

def to_dict(self):
return asdict(self)
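
Since the three dataclasses nest (`EsmFoldConfig` holds a `TrunkConfig`, which holds a `StructureModuleConfig`), constructing the top-level one with no arguments builds the whole default hierarchy. A small sketch (not part of this diff) inspecting those defaults:

```python
# Minimal sketch; assumes the dataclasses above are importable from the module in this diff.
from transformers.models.esm.configuration_esm import EsmFoldConfig

fold_cfg = EsmFoldConfig()
print(fold_cfg.trunk.num_blocks)                     # 48 trunk blocks by default
print(fold_cfg.trunk.structure_module.sequence_dim)  # 384, the single-representation dimension
print(fold_cfg.trunk.structure_module.to_dict()["num_angles"])  # 7 angles from the angle resnet
```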


def get_default_vocab_list():
return (
"<cls>",
"<pad>",
"<eos>",
"<unk>",
"L",
"A",
"G",
"V",
"S",
"E",
"R",
"T",
"I",
"D",
"P",
"K",
"Q",
"N",
"F",
"Y",
"M",
"H",
"W",
"C",
"X",
"B",
"U",
"Z",
"O",
".",
"-",
"<null_1>",
"<mask>",
)
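
The default vocabulary above is ordered, so token ids follow directly from tuple position. A small sketch (not part of this diff) of turning it into a lookup table:

```python
# Minimal sketch; get_default_vocab_list is the function defined just above.
from transformers.models.esm.configuration_esm import get_default_vocab_list

vocab = get_default_vocab_list()
token_to_id = {token: idx for idx, token in enumerate(vocab)}
print(token_to_id["<cls>"], token_to_id["<pad>"], token_to_id["<mask>"])  # 0 1 32
print(len(vocab))  # 33 tokens in total
```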
