diff --git a/docs/source/en/model_doc/esm.mdx b/docs/source/en/model_doc/esm.mdx index fef638c8b4f35b..9462e9db08773b 100644 --- a/docs/source/en/model_doc/esm.mdx +++ b/docs/source/en/model_doc/esm.mdx @@ -14,8 +14,8 @@ specific language governing permissions and limitations under the License. ## Overview This page provides code and pre-trained weights for Transformer protein language models from Meta AI's Fundamental -AI Research Team, providing the state-of-the-art ESM-2, and the previously released ESM-1b and ESM-1v. Transformer -protein language models were introduced in the paper [Biological structure and function emerge from scaling +AI Research Team, providing the state-of-the-art ESMFold and ESM-2, and the previously released ESM-1b and ESM-1v. +Transformer protein language models were introduced in the paper [Biological structure and function emerge from scaling unsupervised learning to 250 million protein sequences](https://www.pnas.org/content/118/15/e2016239118) by Alexander Rives, Joshua Meier, Tom Sercu, Siddharth Goyal, Zeming Lin, Jason Liu, Demi Guo, Myle Ott, C. Lawrence Zitnick, Jerry Ma, and Rob Fergus. @@ -27,6 +27,13 @@ It was released with the paper [Language models of protein sequences at the scal structure prediction](https://doi.org/10.1101/2022.07.20.500902) by Zeming Lin, Halil Akin, Roshan Rao, Brian Hie, Zhongkai Zhu, Wenting Lu, Allan dos Santos Costa, Maryam Fazel-Zarandi, Tom Sercu, Sal Candido and Alexander Rives. +Also introduced in this paper was ESMFold. It uses an ESM-2 stem with a head that can predict folded protein +structures with state-of-the-art accuracy. Unlike [AlphaFold2](https://www.nature.com/articles/s41586-021-03819-2), +it relies on the token embeddings from the large pre-trained protein language model stem and does not perform a multiple +sequence alignment (MSA) step at inference time, which means that ESMFold checkpoints are fully "standalone" - +they do not require a database of known protein sequences and structures with associated external query tools +to make predictions, and are much faster as a result. + The abstract from "Biological structure and function emerge from scaling unsupervised learning to 250 @@ -63,17 +70,22 @@ order of magnitude faster than AlphaFold2, enabling exploration of the structura proteins in practical timescales.* - - Tips: - ESM models are trained with a masked language modeling (MLM) objective. The original code can be found [here](https://github.com/facebookresearch/esm) and was was developed by the Fundamental AI Research team at Meta AI. -This model was contributed to huggingface by [jasonliu](https://huggingface.co/jasonliu) +ESM-1b, ESM-1v and ESM-2 were contributed to huggingface by [jasonliu](https://huggingface.co/jasonliu) and [Matt](https://huggingface.co/Rocketknight1). +ESMFold was contributed to huggingface by [Matt](https://huggingface.co/Rocketknight1) and +[Sylvain](https://huggingface.co/sgugger), with a big thank you to Nikita Smetanin, Roshan Rao and Tom Sercu for their +help throughout the process! + +The HuggingFace port of ESMFold uses portions of the [openfold](https://github.com/aqlaboratory/openfold) library. +The `openfold` library is licensed under the Apache License 2.0. + ## EsmConfig [[autodoc]] EsmConfig @@ -108,6 +120,11 @@ and [Matt](https://huggingface.co/Rocketknight1). [[autodoc]] EsmForTokenClassification - forward +## EsmForProteinFolding + +[[autodoc]] EsmForProteinFolding + - forward + ## TFEsmModel [[autodoc]] TFEsmModel diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index b30d8f719f3a06..e4d8e66de2a808 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -1265,7 +1265,9 @@ _import_structure["models.esm"].extend( [ "ESM_PRETRAINED_MODEL_ARCHIVE_LIST", + "EsmFoldPreTrainedModel", "EsmForMaskedLM", + "EsmForProteinFolding", "EsmForSequenceClassification", "EsmForTokenClassification", "EsmModel", @@ -4144,7 +4146,9 @@ ) from .models.esm import ( ESM_PRETRAINED_MODEL_ARCHIVE_LIST, + EsmFoldPreTrainedModel, EsmForMaskedLM, + EsmForProteinFolding, EsmForSequenceClassification, EsmForTokenClassification, EsmModel, diff --git a/src/transformers/models/esm/__init__.py b/src/transformers/models/esm/__init__.py index b394b7a0c52ed1..0066ed2a3eb4b9 100644 --- a/src/transformers/models/esm/__init__.py +++ b/src/transformers/models/esm/__init__.py @@ -39,6 +39,7 @@ "EsmModel", "EsmPreTrainedModel", ] + _import_structure["modeling_esmfold"] = ["EsmForProteinFolding", "EsmFoldPreTrainedModel"] try: if not is_tf_available(): @@ -55,7 +56,6 @@ "TFEsmPreTrainedModel", ] - if TYPE_CHECKING: from .configuration_esm import ESM_PRETRAINED_CONFIG_ARCHIVE_MAP, EsmConfig from .tokenization_esm import EsmTokenizer @@ -74,6 +74,7 @@ EsmModel, EsmPreTrainedModel, ) + from .modeling_esmfold import EsmFoldPreTrainedModel, EsmForProteinFolding try: if not is_tf_available(): diff --git a/src/transformers/models/esm/configuration_esm.py b/src/transformers/models/esm/configuration_esm.py index 8f184c3e05847a..af6ee73c93c5b8 100644 --- a/src/transformers/models/esm/configuration_esm.py +++ b/src/transformers/models/esm/configuration_esm.py @@ -1,5 +1,5 @@ # coding=utf-8 -# Copyright 2021 Facebook and The HuggingFace Inc. team. All rights reserved. +# Copyright 2022 Meta and The HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,12 +14,16 @@ # limitations under the License. """ ESM model configuration""" +from dataclasses import asdict, dataclass +from typing import Optional + from ...configuration_utils import PretrainedConfig from ...utils import logging logger = logging.get_logger(__name__) +# TODO Update this ESM_PRETRAINED_CONFIG_ARCHIVE_MAP = { "facebook/esm-1b": "https://huggingface.co/facebook/esm-1b/resolve/main/config.json", # See all ESM models at https://huggingface.co/models?filter=esm @@ -118,9 +122,12 @@ def __init__( classifier_dropout=None, emb_layer_norm_before=None, token_dropout=False, + is_folding_model=False, + esmfold_config=None, + vocab_list=None, **kwargs ): - super().__init__(pad_token_id=pad_token_id, **kwargs) + super().__init__(pad_token_id=pad_token_id, mask_token_id=mask_token_id, **kwargs) self.vocab_size = vocab_size self.hidden_size = hidden_size @@ -138,5 +145,225 @@ def __init__( self.classifier_dropout = classifier_dropout self.emb_layer_norm_before = emb_layer_norm_before self.token_dropout = token_dropout - self.mask_token_id = mask_token_id - self.pad_token_id = pad_token_id + self.is_folding_model = is_folding_model + if is_folding_model: + if esmfold_config is None: + logger.info("No esmfold_config supplied for folding model, using default values.") + esmfold_config = EsmFoldConfig() + elif isinstance(esmfold_config, dict): + esmfold_config = EsmFoldConfig(**esmfold_config) + self.esmfold_config = esmfold_config + if vocab_list is None: + logger.warning("No vocab_list supplied for folding model, assuming the ESM-2 vocabulary!") + self.vocab_list = get_default_vocab_list() + else: + self.vocab_list = vocab_list + else: + self.esmfold_config = None + self.vocab_list = None + if self.esmfold_config is not None and getattr(self.esmfold_config, "use_esm_attn_map", False): + raise ValueError("The HuggingFace port of ESMFold does not support use_esm_attn_map at this time!") + + def to_dict(self): + """ + Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`]. + + Returns: + `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance, + """ + output = super().to_dict() + if isinstance(self.esmfold_config, EsmFoldConfig): + output["esmfold_config"] = self.esmfold_config.to_dict() + return output + + +@dataclass +class EsmFoldConfig: + esm_type: str = None + fp16_esm: bool = True + use_esm_attn_map: bool = False + esm_ablate_pairwise: bool = False + esm_ablate_sequence: bool = False + esm_input_dropout: float = 0 + + embed_aa: bool = True + bypass_lm: bool = False + + lddt_head_hid_dim: int = 128 + trunk: "TrunkConfig" = None + + def __post_init__(self): + if self.trunk is None: + self.trunk = TrunkConfig() + elif isinstance(self.trunk, dict): + self.trunk = TrunkConfig(**self.trunk) + + def to_dict(self): + """ + Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`]. + + Returns: + `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance, + """ + output = asdict(self) + output["trunk"] = self.trunk.to_dict() + return output + + +@dataclass +class TrunkConfig: + num_blocks: int = 48 + sequence_state_dim: int = 1024 + pairwise_state_dim: int = 128 + sequence_head_width: int = 32 + pairwise_head_width: int = 32 + position_bins: int = 32 + dropout: float = 0 + layer_drop: float = 0 + cpu_grad_checkpoint: bool = False + max_recycles: int = 4 + chunk_size: Optional[int] = 128 + structure_module: "StructureModuleConfig" = None + + def __post_init__(self): + if self.structure_module is None: + self.structure_module = StructureModuleConfig() + elif isinstance(self.structure_module, dict): + self.structure_module = StructureModuleConfig(**self.structure_module) + + if self.max_recycles <= 0: + raise ValueError(f"`max_recycles` should be positive, got {self.max_recycles}.") + if self.sequence_state_dim % self.sequence_state_dim != 0: + raise ValueError( + "`sequence_state_dim` should be a round multiple of `sequence_state_dim`, got" + f" {self.sequence_state_dim} and {self.sequence_state_dim}." + ) + if self.pairwise_state_dim % self.pairwise_state_dim != 0: + raise ValueError( + "`pairwise_state_dim` should be a round multiple of `pairwise_state_dim`, got" + f" {self.pairwise_state_dim} and {self.pairwise_state_dim}." + ) + + sequence_num_heads = self.sequence_state_dim // self.sequence_head_width + pairwise_num_heads = self.pairwise_state_dim // self.pairwise_head_width + + if self.sequence_state_dim != sequence_num_heads * self.sequence_head_width: + raise ValueError( + "`sequence_state_dim` should be equal to `sequence_num_heads * sequence_head_width, got" + f" {self.sequence_state_dim} != {sequence_num_heads} * {self.sequence_head_width}." + ) + if self.pairwise_state_dim != pairwise_num_heads * self.pairwise_head_width: + raise ValueError( + "`pairwise_state_dim` should be equal to `pairwise_num_heads * pairwise_head_width, got" + f" {self.pairwise_state_dim} != {pairwise_num_heads} * {self.pairwise_head_width}." + ) + if self.pairwise_state_dim % 2 != 0: + raise ValueError(f"`pairwise_state_dim` should be even, got {self.pairwise_state_dim}.") + + if self.dropout >= 0.4: + raise ValueError(f"`dropout` should not be greater than 0.4, got {self.dropout}.") + + def to_dict(self): + """ + Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`]. + + Returns: + `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance, + """ + output = asdict(self) + output["structure_module"] = self.structure_module.to_dict() + return output + + +@dataclass +class StructureModuleConfig: + """ + Args: + sequence_dim: + Single representation channel dimension + pairwise_dim: + Pair representation channel dimension + ipa_dim: + IPA hidden channel dimension + resnet_dim: + Angle resnet (Alg. 23 lines 11-14) hidden channel dimension + num_heads_ipa: + Number of IPA heads + num_qk_points: + Number of query/key points to generate during IPA + num_v_points: + Number of value points to generate during IPA + dropout_rate: + Dropout rate used throughout the layer + num_blocks: + Number of structure module blocks + num_transition_layers: + Number of layers in the single representation transition (Alg. 23 lines 8-9) + num_resnet_blocks: + Number of blocks in the angle resnet + num_angles: + Number of angles to generate in the angle resnet + trans_scale_factor: + Scale of single representation transition hidden dimension + epsilon: + Small number used in angle resnet normalization + inf: + Large number used for attention masking + """ + + sequence_dim: int = 384 + pairwise_dim: int = 128 + ipa_dim: int = 16 + resnet_dim: int = 128 + num_heads_ipa: int = 12 + num_qk_points: int = 4 + num_v_points: int = 8 + dropout_rate: float = 0.1 + num_blocks: int = 8 + num_transition_layers: int = 1 + num_resnet_blocks: int = 2 + num_angles: int = 7 + trans_scale_factor: int = 10 + epsilon: float = 1e-8 + inf: float = 1e5 + + def to_dict(self): + return asdict(self) + + +def get_default_vocab_list(): + return ( + "", + "", + "", + "", + "L", + "A", + "G", + "V", + "S", + "E", + "R", + "T", + "I", + "D", + "P", + "K", + "Q", + "N", + "F", + "Y", + "M", + "H", + "W", + "C", + "X", + "B", + "U", + "Z", + "O", + ".", + "-", + "", + "", + ) diff --git a/src/transformers/models/esm/convert_esm.py b/src/transformers/models/esm/convert_esm.py index 991f3594863762..3f92d863d4327f 100644 --- a/src/transformers/models/esm/convert_esm.py +++ b/src/transformers/models/esm/convert_esm.py @@ -23,7 +23,8 @@ import torch import esm as esm_module -from transformers.models.esm.configuration_esm import EsmConfig +from esm.esmfold.v1.pretrained import esmfold_v1 +from transformers.models.esm.configuration_esm import EsmConfig, EsmFoldConfig from transformers.models.esm.modeling_esm import ( EsmForMaskedLM, EsmForSequenceClassification, @@ -33,6 +34,7 @@ EsmSelfAttention, EsmSelfOutput, ) +from transformers.models.esm.modeling_esmfold import EsmForProteinFolding from transformers.models.esm.tokenization_esm import EsmTokenizer from transformers.utils import logging @@ -60,19 +62,51 @@ "esm2_t30_150M_UR50D": esm_module.pretrained.esm2_t30_150M_UR50D, "esm2_t12_35M_UR50D": esm_module.pretrained.esm2_t12_35M_UR50D, "esm2_t6_8M_UR50D": esm_module.pretrained.esm2_t6_8M_UR50D, + "esmfold_v1": esmfold_v1, } +def transfer_and_check_weights(original_module, our_module): + status = our_module.load_state_dict(original_module.state_dict()) + if status.missing_keys: + raise ValueError(f"Missing keys: {status.missing_keys}") + if status.unexpected_keys: + raise ValueError(f"Unexpected keys: {status.unexpected_keys}") + + def convert_esm_checkpoint_to_pytorch( model: str, pytorch_dump_folder_path: str, classification_head: bool, push_to_repo: str, auth_token: str ): """ Copy/paste/tweak esm's weights to our BERT structure. """ - esm, alphabet = MODEL_MAPPING[model]() + if model.startswith("esmfold"): + esm = MODEL_MAPPING[model]() + alphabet = esm.esm.alphabet + else: + esm, alphabet = MODEL_MAPPING[model]() esm.eval() # disable dropout - esm_sent_encoder = esm - if hasattr(esm, "args"): + + if model.startswith("esmfold"): + embed_dim = esm.esm.embed_dim + num_layers = esm.esm.num_layers + num_attention_heads = esm.esm.attention_heads + intermediate_size = 4 * embed_dim + token_dropout = esm.esm.token_dropout + emb_layer_norm_before = False # This code path does not exist in ESM-2 + position_embedding_type = "rotary" + is_folding_model = True + esmfold_config = EsmFoldConfig() + for key, val in esm.cfg.items(): + if hasattr(esmfold_config, key) and key != "trunk": + setattr(esmfold_config, key, val) + for key, val in esm.cfg.trunk.items(): + if hasattr(esmfold_config.trunk, key) and key != "structure_module": + setattr(esmfold_config.trunk, key, val) + for key, val in esm.cfg.trunk.structure_module.items(): + if hasattr(esmfold_config.trunk.structure_module, key): + setattr(esmfold_config.trunk.structure_module, key, val) + elif hasattr(esm, "args"): # Indicates an ESM-1b or ESM-1v model embed_dim = esm.args.embed_dim num_layers = esm.args.layers @@ -81,6 +115,8 @@ def convert_esm_checkpoint_to_pytorch( token_dropout = esm.args.token_dropout emb_layer_norm_before = True if esm.emb_layer_norm_before else False position_embedding_type = "absolute" + is_folding_model = False + esmfold_config = None else: # Indicates an ESM-2 model embed_dim = esm.embed_dim @@ -90,9 +126,18 @@ def convert_esm_checkpoint_to_pytorch( token_dropout = esm.token_dropout emb_layer_norm_before = False # This code path does not exist in ESM-2 position_embedding_type = "rotary" + is_folding_model = False + esmfold_config = None + + vocab_list = tuple(alphabet.all_toks) + + if is_folding_model: + original_esm_model = esm.esm + else: + original_esm_model = esm config = EsmConfig( - vocab_size=esm_sent_encoder.embed_tokens.num_embeddings, + vocab_size=original_esm_model.embed_tokens.num_embeddings, mask_token_id=alphabet.mask_idx, hidden_size=embed_dim, num_hidden_layers=num_layers, @@ -102,36 +147,45 @@ def convert_esm_checkpoint_to_pytorch( layer_norm_eps=1e-5, # PyTorch default used in fairseq attention_probs_dropout_prob=0.0, hidden_dropout_prob=0.0, - pad_token_id=esm.padding_idx, + pad_token_id=alphabet.padding_idx, emb_layer_norm_before=emb_layer_norm_before, token_dropout=token_dropout, position_embedding_type=position_embedding_type, + is_folding_model=is_folding_model, + esmfold_config=esmfold_config, + vocab_list=vocab_list, ) if classification_head: config.num_labels = esm.classification_heads["mnli"].out_proj.weight.shape[0] - print("Our BERT config:", config) + print("Our ESM config:", config) - model = EsmForSequenceClassification(config) if classification_head else EsmForMaskedLM(config) + if model.startswith("esmfold"): + model_class = EsmForProteinFolding + elif classification_head: + model_class = EsmForSequenceClassification + else: + model_class = EsmForMaskedLM + model = model_class(config) model.eval() # Now let's copy all the weights. # Embeddings - model.esm.embeddings.word_embeddings.weight = esm_sent_encoder.embed_tokens.weight + model.esm.embeddings.word_embeddings.weight = original_esm_model.embed_tokens.weight if position_embedding_type == "absolute": - model.esm.embeddings.position_embeddings.weight = esm_sent_encoder.embed_positions.weight + model.esm.embeddings.position_embeddings.weight = original_esm_model.embed_positions.weight if config.emb_layer_norm_before: - model.esm.embeddings.layer_norm.weight = esm_sent_encoder.emb_layer_norm_before.weight - model.esm.embeddings.layer_norm.bias = esm_sent_encoder.emb_layer_norm_before.bias + model.esm.embeddings.layer_norm.weight = original_esm_model.emb_layer_norm_before.weight + model.esm.embeddings.layer_norm.bias = original_esm_model.emb_layer_norm_before.bias - model.esm.encoder.emb_layer_norm_after.weight = esm_sent_encoder.emb_layer_norm_after.weight - model.esm.encoder.emb_layer_norm_after.bias = esm_sent_encoder.emb_layer_norm_after.bias + model.esm.encoder.emb_layer_norm_after.weight = original_esm_model.emb_layer_norm_after.weight + model.esm.encoder.emb_layer_norm_after.bias = original_esm_model.emb_layer_norm_after.bias for i in range(config.num_hidden_layers): # Encoder: start of layer layer: EsmLayer = model.esm.encoder.layer[i] - # esm_layer: TransformerSentenceEncoderLayer = esm_sent_encoder.layers[i] - esm_layer = esm_sent_encoder.layers[i] + # esm_layer: TransformerSentenceEncoderLayer = original_esm_model.layers[i] + esm_layer = original_esm_model.layers[i] # self attention self_attn: EsmSelfAttention = layer.attention.self @@ -183,7 +237,17 @@ def convert_esm_checkpoint_to_pytorch( bert_output.dense.bias = esm_layer.fc2.bias # end of layer - if classification_head: + if is_folding_model: + model.esm_s_combine.data = esm.esm_s_combine.data + transfer_and_check_weights(esm.embedding, model.embedding) + transfer_and_check_weights(esm.esm_s_mlp, model.esm_s_mlp) + transfer_and_check_weights(esm.trunk, model.trunk) + transfer_and_check_weights(esm.distogram_head, model.distogram_head) + transfer_and_check_weights(esm.ptm_head, model.ptm_head) + transfer_and_check_weights(esm.lm_head, model.lm_head) + transfer_and_check_weights(esm.lddt_head, model.lddt_head) + + elif classification_head: model.classifier.dense.weight = esm.esm.classification_heads["mnli"].dense.weight model.classifier.dense.bias = esm.classification_heads["mnli"].dense.bias model.classifier.out_proj.weight = esm.classification_heads["mnli"].out_proj.weight @@ -195,15 +259,19 @@ def convert_esm_checkpoint_to_pytorch( model.lm_head.layer_norm.weight = esm.lm_head.layer_norm.weight model.lm_head.layer_norm.bias = esm.lm_head.layer_norm.bias model.lm_head.decoder.weight = esm.lm_head.weight - model.lm_head.decoder.bias = esm.lm_head.bias + model.lm_head.bias = esm.lm_head.bias # Let's check that we get the same results. batch_converter = alphabet.get_batch_converter() # Prepare data (first 2 sequences from ESMStructuralSplitDataset superfamily / 4) + if is_folding_model: + # Folding models aren't trained on masked inputs and don't like mask tokens. + sample_data = SAMPLE_DATA[:2] + else: + sample_data = SAMPLE_DATA - batch_labels, batch_strs, batch_tokens = batch_converter(SAMPLE_DATA) - + batch_labels, batch_strs, batch_tokens = batch_converter(sample_data) # Prepare tokenizer and make sure it matches with TemporaryDirectory() as tempdir: vocab = "\n".join(alphabet.all_toks) @@ -211,32 +279,66 @@ def convert_esm_checkpoint_to_pytorch( vocab_file.write_text(vocab) hf_tokenizer = EsmTokenizer(vocab_file=str(vocab_file)) - hf_tokens = hf_tokenizer([row[1] for row in SAMPLE_DATA], return_tensors="pt", padding=True) + hf_tokens = hf_tokenizer([row[1] for row in sample_data], return_tensors="pt", padding=True) success = torch.all(hf_tokens["input_ids"] == batch_tokens) print("Do both models tokenizers output the same tokens?", "🔥" if success else "💩") if not success: raise Exception("Tokenization does not match!") with torch.no_grad(): - our_output = model(**hf_tokens, output_hidden_states=True) - our_output = our_output["logits"] - if classification_head: - their_output = esm.model.classification_heads["mnli"](esm.extract_features(batch_tokens)) + if is_folding_model: + # Let's test the model in parts + # ESMFold always converts the ESM stem to float16, which requires float16 ops + # that don't exist on CPU. Therefore, to test it we need to run it on GPU. However, + # ESMFold is what we in the community call a "big boy" and so we desperately avoid putting both the + # original and the converted model on the GPU at the same time. + our_output = model.cuda()( + input_ids=hf_tokens["input_ids"].cuda(), attention_mask=hf_tokens["attention_mask"].cuda() + ) + their_output = esm.cuda()(hf_tokens["input_ids"].cuda(), hf_tokens["attention_mask"].cuda()) + else: + our_output = model(**hf_tokens, output_hidden_states=True) + our_output = our_output["logits"] + if classification_head: + their_output = esm.model.classification_heads["mnli"](esm.extract_features(batch_tokens)) + else: + their_output = esm(hf_tokens["input_ids"], repr_layers=list(range(999))) + their_output = their_output["logits"] + + if is_folding_model: + max_absolute_diff = torch.max(torch.abs(our_output["positions"] - their_output["positions"])).item() + success = torch.allclose(our_output["positions"], their_output["positions"], atol=1e-5) else: - their_output = esm(batch_tokens, repr_layers=list(range(999))) - their_output = their_output["logits"] - print(our_output.shape, their_output.shape) - max_absolute_diff = torch.max(torch.abs(our_output - their_output)).item() - print(f"max_absolute_diff = {max_absolute_diff}") # ~ 1e-5 - success = torch.allclose(our_output, their_output, atol=3e-4) - print("Do both models output the same tensors?", "🔥" if success else "💩") + max_absolute_diff = torch.max(torch.abs(our_output - their_output)).item() + success = torch.allclose(our_output, their_output, atol=1e-5) - if not success: - raise Exception("Something went wRoNg") + print(f"max_absolute_diff = {max_absolute_diff}") # ~ 1e-5 + print("Do both models output the same tensors?", "🔥" if success else "💩") + + if not success: + raise Exception("Something went wRoNg") + + pathlib.Path(pytorch_dump_folder_path).mkdir(parents=True, exist_ok=True) + print(f"Saving model to {pytorch_dump_folder_path}") + model.save_pretrained(pytorch_dump_folder_path) + + reloaded = model_class.from_pretrained(pytorch_dump_folder_path).cuda() + reloaded_output = reloaded( + input_ids=hf_tokens["input_ids"].cuda(), attention_mask=hf_tokens["attention_mask"].cuda() + ) + + if is_folding_model: + max_absolute_diff = torch.max(torch.abs(our_output["positions"] - reloaded_output["positions"])).item() + success = torch.allclose(our_output["positions"], their_output["positions"], atol=1e-6) + else: + max_absolute_diff = torch.max(torch.abs(our_output - reloaded_output["logits"])).item() + success = torch.allclose(our_output, reloaded_output["logits"], atol=1e-6) + + print(f"max_absolute_diff = {max_absolute_diff}") + print("Does the model output the same tensors after reloading?", "🔥" if success else "💩") - pathlib.Path(pytorch_dump_folder_path).mkdir(parents=True, exist_ok=True) - print(f"Saving model to {pytorch_dump_folder_path}") - model.save_pretrained(pytorch_dump_folder_path) + if not success: + raise Exception("Something went wRoNg") print(f"Saving tokenizer to {pytorch_dump_folder_path}") hf_tokenizer.save_pretrained(pytorch_dump_folder_path) diff --git a/src/transformers/models/esm/modeling_esm.py b/src/transformers/models/esm/modeling_esm.py index 498c1cf737aaa7..09dd793f7215df 100755 --- a/src/transformers/models/esm/modeling_esm.py +++ b/src/transformers/models/esm/modeling_esm.py @@ -14,6 +14,7 @@ # limitations under the License. """ PyTorch ESM model.""" +import math from typing import List, Optional, Tuple, Union import torch @@ -21,7 +22,6 @@ from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss -from ...activations import ACT2FN, gelu from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward from ...modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, @@ -30,12 +30,7 @@ SequenceClassifierOutput, TokenClassifierOutput, ) -from ...modeling_utils import ( - PreTrainedModel, - apply_chunking_to_forward, - find_pruneable_heads_and_indices, - prune_linear_layer, -) +from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer from ...utils import logging from .configuration_esm import EsmConfig @@ -66,6 +61,13 @@ def apply_rotary_pos_emb(x, cos, sin): return (x * cos) + (rotate_half(x) * sin) +def gelu(x): + """ + This is the gelu implementation from the original ESM repo. Using F.gelu yields subtly wrong results. + """ + return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))) + + class RotaryEmbedding(torch.nn.Module): """ Rotary position embeddings based on those in @@ -163,7 +165,9 @@ def forward( mask_ratio_train = 0.15 * 0.8 # Hardcoded as the ratio used in all ESM model training runs src_lengths = attention_mask.sum(-1) mask_ratio_observed = (input_ids == self.mask_token_id).sum(-1).float() / src_lengths - embeddings = embeddings * (1 - mask_ratio_train) / (1 - mask_ratio_observed)[:, None, None] + embeddings = (embeddings * (1 - mask_ratio_train) / (1 - mask_ratio_observed)[:, None, None]).to( + embeddings.dtype + ) if self.position_embedding_type == "absolute": position_embeddings = self.position_embeddings(position_ids) @@ -172,7 +176,7 @@ def forward( if self.layer_norm is not None: embeddings = self.layer_norm(embeddings) if attention_mask is not None: - embeddings = embeddings * attention_mask.unsqueeze(-1) + embeddings = (embeddings * attention_mask.unsqueeze(-1)).to(embeddings.dtype) # Matt: I think this line was copied incorrectly from BERT, disabling it for now. # embeddings = self.dropout(embeddings) return embeddings @@ -398,19 +402,14 @@ def forward( return outputs -# Copied from transformers.models.bert.modeling_bert.BertIntermediate class EsmIntermediate(nn.Module): def __init__(self, config): super().__init__() self.dense = nn.Linear(config.hidden_size, config.intermediate_size) - if isinstance(config.hidden_act, str): - self.intermediate_act_fn = ACT2FN[config.hidden_act] - else: - self.intermediate_act_fn = config.hidden_act def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_states = self.dense(hidden_states) - hidden_states = self.intermediate_act_fn(hidden_states) + hidden_states = gelu(hidden_states) return hidden_states @@ -497,15 +496,13 @@ def forward( cross_attn_present_key_value = cross_attention_outputs[-1] present_key_value = present_key_value + cross_attn_present_key_value - layer_output = apply_chunking_to_forward( - self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output - ) + layer_output = self.feed_forward_chunk(attention_output) + outputs = (layer_output,) + outputs # if decoder, return the attn key/values as the last output if self.is_decoder: outputs = outputs + (present_key_value,) - return outputs def feed_forward_chunk(self, attention_output): diff --git a/src/transformers/models/esm/modeling_esmfold.py b/src/transformers/models/esm/modeling_esmfold.py new file mode 100644 index 00000000000000..b32769ee226c76 --- /dev/null +++ b/src/transformers/models/esm/modeling_esmfold.py @@ -0,0 +1,2307 @@ +# coding=utf-8 +# Copyright 2022 Meta and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import math +import sys +from dataclasses import dataclass +from functools import partial +from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union + +import numpy as np +import torch +import torch.nn as nn +from torch.nn import LayerNorm + +from ...deepspeed import is_deepspeed_available +from ...modeling_outputs import ModelOutput +from ...utils import ( + ContextManagers, + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_scipy_available, + logging, + replace_return_docstrings, +) +from .configuration_esm import EsmConfig +from .modeling_esm import ESM_START_DOCSTRING, EsmModel, EsmPreTrainedModel +from .openfold_utils import ( + OFProtein, + Rigid, + Rotation, + atom14_to_atom37, + chunk_layer, + compute_predicted_aligned_error, + compute_tm, + frames_and_literature_positions_to_atom14_pos, + make_atom14_masks, + residue_constants, + to_pdb, + torsion_angles_to_frames, +) + + +logger = logging.get_logger(__name__) + + +@dataclass +class EsmForProteinFoldingOutput(ModelOutput): + """ + Output type of [`EsmForProteinFoldingOutput`]. + + Args: + frames (`torch.FloatTensor`): + Output frames. + sidechain_frames (`torch.FloatTensor`): + Output sidechain frames. + unnormalized_angles (`torch.FloatTensor`): + Predicted unnormalized backbone and side chain torsion angles. + angles (`torch.FloatTensor`): + Predicted backbone and side chain torsion angles. + positions (`torch.FloatTensor`): + Predicted positions of the backbone and side chain atoms. + states (`torch.FloatTensor`): + Hidden states from the protein folding trunk. + s_s (`torch.FloatTensor`): + Per-residue embeddings derived by concatenating the hidden states of each layer of the ESM-2 LM stem. + s_z (`torch.FloatTensor`): + Pairwise residue embeddings. + distogram_logits (`torch.FloatTensor`): + Input logits to the distogram used to compute residue distances. + lm_logits (`torch.FloatTensor`): + Logits output by the ESM-2 protein language model stem. + aatype (`torch.FloatTensor`): + Input amino acids (AlphaFold2 indices). + atom14_atom_exists (`torch.FloatTensor`): + Whether each atom exists in the atom14 representation. + residx_atom14_to_atom37 (`torch.FloatTensor`): + Mapping between atoms in the atom14 and atom37 representations. + residx_atom37_to_atom14 (`torch.FloatTensor`): + Mapping between atoms in the atom37 and atom14 representations. + atom37_atom_exists (`torch.FloatTensor`): + Whether each atom exists in the atom37 representation. + residue_index (`torch.FloatTensor`): + The index of each residue in the protein chain. Unless internal padding tokens are used, this will just be + a sequence of integers from 0 to `sequence_length`. + lddt_head (`torch.FloatTensor`): + Raw outputs from the lddt head used to compute plddt. + plddt (`torch.FloatTensor`): + Per-residue confidence scores. Regions of low confidence may indicate areas where the model's prediction is + uncertain, or where the protein structure is disordered. + ptm_logits (`torch.FloatTensor`): + Raw logits used for computing ptm. + ptm (`torch.FloatTensor`): + TM-score output representing the model's high-level confidence in the overall structure. + aligned_confidence_probs (`torch.FloatTensor`): + Per-residue confidence scores for the aligned structure. + predicted_aligned_error (`torch.FloatTensor`): + Predicted error between the model's prediction and the ground truth. + max_predicted_aligned_error (`torch.FloatTensor`): + Per-sample maximum predicted error. + """ + + frames: torch.FloatTensor = None + sidechain_frames: torch.FloatTensor = None + unnormalized_angles: torch.FloatTensor = None + angles: torch.FloatTensor = None + positions: torch.FloatTensor = None + states: torch.FloatTensor = None + s_s: torch.FloatTensor = None + s_z: torch.FloatTensor = None + distogram_logits: torch.FloatTensor = None + lm_logits: torch.FloatTensor = None + aatype: torch.FloatTensor = None + atom14_atom_exists: torch.FloatTensor = None + residx_atom14_to_atom37: torch.FloatTensor = None + residx_atom37_to_atom14: torch.FloatTensor = None + atom37_atom_exists: torch.FloatTensor = None + residue_index: torch.FloatTensor = None + lddt_head: torch.FloatTensor = None + plddt: torch.FloatTensor = None + ptm_logits: torch.FloatTensor = None + ptm: torch.FloatTensor = None + aligned_confidence_probs: torch.FloatTensor = None + predicted_aligned_error: torch.FloatTensor = None + max_predicted_aligned_error: torch.FloatTensor = None + + +ESMFOLD_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `({0})`): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using [`EsmTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + position_ids (`torch.LongTensor` of shape `({0})`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. + + [What are position IDs?](../glossary#position-ids) + masking_pattern (`torch.LongTensor` of shape `({0})`, *optional*): + Locations of tokens to mask during training as a form of regularization. Mask values selected in `[0, 1]`. + num_recycles (`int`, *optional*, defaults to `None`): + Number of times to recycle the input sequence. If `None`, defaults to `config.num_recycles`. "Recycling" + consists of passing the output of the folding trunk back in as input to the trunk. During training, the + number of recycles should vary with each batch, to ensure that the model learns to output valid predictions + after each recycle. During inference, num_recycles should be set to the highest value that the model was + trained with for maximum accuracy. Accordingly, when this value is set to `None`, config.max_recycles is + used. +""" + + +def is_fp16_enabled(): + # Autocast world + fp16_enabled = torch.get_autocast_gpu_dtype() == torch.float16 + fp16_enabled = fp16_enabled and torch.is_autocast_enabled() + + return fp16_enabled + + +def is_deepspeed_initialized(): + if is_deepspeed_available(): + return False + else: + try: + import deepspeed + + # This is not available in all DeepSpeed versions. + return deepspeed.utils.is_initialized() + except Exception: + return False + + +def collate_dense_tensors(samples: List[torch.Tensor], pad_v: float = 0) -> torch.Tensor: + """ + Takes a list of tensors with the following dimensions: + [(d_11, ..., d_1K), + (d_21, ..., d_2K), ..., (d_N1, ..., d_NK)] + and stack + pads them into a single tensor of: + (N, max_i=1,N { d_i1 }, ..., max_i=1,N {diK}) + """ + if len(samples) == 0: + return torch.Tensor() + if len(set(x.dim() for x in samples)) != 1: + raise RuntimeError(f"Samples has varying dimensions: {[x.dim() for x in samples]}") + (device,) = tuple(set(x.device for x in samples)) # assumes all on same device + max_shape = [max(lst) for lst in zip(*[x.shape for x in samples])] + result = torch.empty(len(samples), *max_shape, dtype=samples[0].dtype, device=device) + result.fill_(pad_v) + for i in range(len(samples)): + result_i = result[i] + t = samples[i] + result_i[tuple(slice(0, k) for k in t.shape)] = t + return result + + +def flatten_final_dims(t: torch.Tensor, no_dims: int): + return t.reshape(t.shape[:-no_dims] + (-1,)) + + +def permute_final_dims(tensor: torch.Tensor, inds: List[int]): + zero_index = -1 * len(inds) + first_inds = list(range(len(tensor.shape[:zero_index]))) + return tensor.permute(first_inds + [zero_index + i for i in inds]) + + +def dict_multimap(fn, dicts): + first = dicts[0] + new_dict = {} + for k, v in first.items(): + all_v = [d[k] for d in dicts] + if type(v) is dict: + new_dict[k] = dict_multimap(fn, all_v) + else: + new_dict[k] = fn(all_v) + + return new_dict + + +def trunc_normal_init_(weights, scale=1.0, fan="fan_in"): + shape = weights.shape + scale = scale / max(1, shape[1]) + + if not is_scipy_available(): + logger.warning( + "This init requires scipy, but scipy was not found, default to an approximation that might not be" + " equivalent." + ) + std = math.sqrt(scale) + torch.nn.init.normal_(weights, std=std).clamp(min=0.0, max=2.0 * std) + + else: + from scipy.stats import truncnorm + + std = math.sqrt(scale) / truncnorm.std(a=-2, b=2, loc=0, scale=1) + samples = truncnorm.rvs(a=-2, b=2, loc=0, scale=std, size=weights.numel()) + samples = np.reshape(samples, shape) + weights.copy_(torch.tensor(samples, device=weights.device)) + + +def ipa_point_weights_init_(weights): + with torch.no_grad(): + softplus_inverse_1 = 0.541324854612918 + weights.fill_(softplus_inverse_1) + + +class EsmFoldLinear(nn.Linear): + """ + A Linear layer with built-in nonstandard initializations. Called just like torch.nn.Linear. + + Implements the initializers in 1.11.4, plus some additional ones found in the code. + """ + + def __init__( + self, + in_dim: int, + out_dim: int, + bias: bool = True, + init: str = "default", + init_fn: Optional[Callable[[torch.Tensor, torch.Tensor], None]] = None, + ): + """ + Args: + in_dim: + The final dimension of inputs to the layer + out_dim: + The final dimension of layer outputs + bias: + Whether to learn an additive bias. True by default + init: + The initializer to use. Choose from: + + "default": LeCun fan-in truncated normal initialization "relu": He initialization w/ truncated normal + distribution "glorot": Fan-average Glorot uniform initialization "gating": Weights=0, Bias=1 "normal": + Normal initialization with std=1/sqrt(fan_in) "final": Weights=0, Bias=0 + + Overridden by init_fn if the latter is not None. + init_fn: + A custom initializer taking weight and bias as inputs. Overrides init if not None. + """ + super().__init__(in_dim, out_dim, bias=bias) + + if bias: + with torch.no_grad(): + self.bias.fill_(0) + self.init = init + self.init_fn = init_fn + + if init not in ["default", "relu", "glorot", "gating", "normal", "final"]: + raise ValueError("Invalid init string.") + + +class EsmFoldLayerNorm(nn.Module): + def __init__(self, c_in, eps=1e-5): + super().__init__() + + self.c_in = (c_in,) + self.eps = eps + + self.weight = nn.Parameter(torch.ones(c_in)) + self.bias = nn.Parameter(torch.zeros(c_in)) + + def forward(self, x): + d = x.dtype + if d is torch.bfloat16 and not is_deepspeed_initialized(): + with torch.cuda.amp.autocast(enabled=False): + out = nn.functional.layer_norm(x, self.c_in, self.weight.to(dtype=d), self.bias.to(dtype=d), self.eps) + else: + out = nn.functional.layer_norm(x, self.c_in, self.weight, self.bias, self.eps) + + return out + + +@torch.jit.ignore +def softmax_no_cast(t: torch.Tensor, dim: int = -1) -> torch.Tensor: + """ + Softmax, but without automatic casting to fp32 when the input is of type bfloat16 + """ + d = t.dtype + if d is torch.bfloat16 and not is_deepspeed_initialized(): + with torch.cuda.amp.autocast(enabled=False): + s = torch.nn.functional.softmax(t, dim=dim) + else: + s = torch.nn.functional.softmax(t, dim=dim) + + return s + + +class EsmFoldAttention(nn.Module): + """ + Standard multi-head attention using AlphaFold's default layer initialization. Allows multiple bias vectors. + """ + + def __init__( + self, + c_q: int, + c_k: int, + c_v: int, + c_hidden: int, + no_heads: int, + gating: bool = True, + ): + """ + Args: + c_q: + Input dimension of query data + c_k: + Input dimension of key data + c_v: + Input dimension of value data + c_hidden: + Per-head hidden dimension + no_heads: + Number of attention heads + gating: + Whether the output should be gated using query data + """ + super().__init__() + + self.c_q = c_q + self.c_k = c_k + self.c_v = c_v + self.c_hidden = c_hidden + self.no_heads = no_heads + self.gating = gating + + # DISCREPANCY: c_hidden is not the per-head channel dimension, as + # stated in the supplement, but the overall channel dimension. + + self.linear_q = EsmFoldLinear(self.c_q, self.c_hidden * self.no_heads, bias=False, init="glorot") + self.linear_k = EsmFoldLinear(self.c_k, self.c_hidden * self.no_heads, bias=False, init="glorot") + self.linear_v = EsmFoldLinear(self.c_v, self.c_hidden * self.no_heads, bias=False, init="glorot") + self.linear_o = EsmFoldLinear(self.c_hidden * self.no_heads, self.c_q, init="final") + + self.linear_g = None + if self.gating: + self.linear_g = EsmFoldLinear(self.c_q, self.c_hidden * self.no_heads, init="gating") + + self.sigmoid = nn.Sigmoid() + + def _prep_qkv(self, q_x: torch.Tensor, kv_x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + # [*, Q/K/V, H * C_hidden] + q = self.linear_q(q_x) + k = self.linear_k(kv_x) + v = self.linear_v(kv_x) + + # [*, Q/K, H, C_hidden] + q = q.view(q.shape[:-1] + (self.no_heads, -1)) + k = k.view(k.shape[:-1] + (self.no_heads, -1)) + v = v.view(v.shape[:-1] + (self.no_heads, -1)) + + # [*, H, Q/K, C_hidden] + q = q.transpose(-2, -3) + k = k.transpose(-2, -3) + v = v.transpose(-2, -3) + + q /= math.sqrt(self.c_hidden) + + return q, k, v + + def _wrap_up(self, o: torch.Tensor, q_x: torch.Tensor) -> torch.Tensor: + if self.linear_g is not None: + g = self.sigmoid(self.linear_g(q_x)) + + # [*, Q, H, C_hidden] + g = g.view(g.shape[:-1] + (self.no_heads, -1)) + o = o * g + + # [*, Q, H * C_hidden] + o = flatten_final_dims(o, 2) + + # [*, Q, C_q] + o = self.linear_o(o) + + return o + + def forward( + self, + q_x: torch.Tensor, + kv_x: torch.Tensor, + biases: Optional[List[torch.Tensor]] = None, + use_memory_efficient_kernel: bool = False, + use_lma: bool = False, + lma_q_chunk_size: int = 1024, + lma_kv_chunk_size: int = 4096, + use_flash: bool = False, + flash_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """ + Args: + q_x: + [*, Q, C_q] query data + kv_x: + [*, K, C_k] key data + biases: + List of biases that broadcast to [*, H, Q, K] + use_memory_efficient_kernel: + Whether to use a custom memory-efficient attention kernel. This should be the default choice for most. + If none of the "use_<...>" flags are True, a stock PyTorch implementation is used instead + use_lma: + Whether to use low-memory attention (Staats & Rabe 2021). If none of the "use_<...>" flags are True, a + stock PyTorch implementation is used instead + lma_q_chunk_size: + Query chunk size (for LMA) + lma_kv_chunk_size: + Key/Value chunk size (for LMA) + Returns + [*, Q, C_q] attention update + """ + if use_lma and (lma_q_chunk_size is None or lma_kv_chunk_size is None): + raise ValueError("If use_lma is specified, lma_q_chunk_size and lma_kv_chunk_size must be provided") + + if use_flash and biases is not None: + raise ValueError("use_flash is incompatible with the bias option. For masking, use flash_mask instead") + + attn_options = [use_memory_efficient_kernel, use_lma, use_flash] + if sum(attn_options) > 1: + raise ValueError("Choose at most one alternative attention algorithm") + + if biases is None: + biases = [] + + # [*, H, Q/K, C_hidden] + query, key, value = self._prep_qkv(q_x, kv_x) + key = permute_final_dims(key, (1, 0)) + + # [*, H, Q, K] + output = torch.matmul(query, key) + for b in biases: + output += b + output = softmax_no_cast(output, -1) + + # [*, H, Q, C_hidden] + output = torch.matmul(output, value) + output = output.transpose(-2, -3) + output = self._wrap_up(output, q_x) + + return output + + +class EsmFoldTriangleAttention(nn.Module): + def __init__(self, c_in, c_hidden, no_heads, starting=True, inf=1e9): + """ + Args: + c_in: + Input channel dimension + c_hidden: + Overall hidden channel dimension (not per-head) + no_heads: + Number of attention heads + """ + super().__init__() + + self.c_in = c_in + self.c_hidden = c_hidden + self.no_heads = no_heads + self.starting = starting + self.inf = inf + + self.layer_norm = LayerNorm(self.c_in) + + self.linear = EsmFoldLinear(c_in, self.no_heads, bias=False, init="normal") + + self.mha = EsmFoldAttention(self.c_in, self.c_in, self.c_in, self.c_hidden, self.no_heads) + + @torch.jit.ignore + def _chunk( + self, + x: torch.Tensor, + biases: List[torch.Tensor], + chunk_size: int, + use_memory_efficient_kernel: bool = False, + use_lma: bool = False, + inplace_safe: bool = False, + ) -> torch.Tensor: + "triangle! triangle!" + mha_inputs = { + "q_x": x, + "kv_x": x, + "biases": biases, + } + + return chunk_layer( + partial(self.mha, use_memory_efficient_kernel=use_memory_efficient_kernel, use_lma=use_lma), + mha_inputs, + chunk_size=chunk_size, + no_batch_dims=len(x.shape[:-2]), + _out=x if inplace_safe else None, + ) + + def forward( + self, + x: torch.Tensor, + mask: Optional[torch.Tensor] = None, + chunk_size: Optional[int] = None, + use_memory_efficient_kernel: bool = False, + use_lma: bool = False, + inplace_safe: bool = False, + ) -> torch.Tensor: + """ + Args: + x: + [*, I, J, C_in] input tensor (e.g. the pair representation) + Returns: + [*, I, J, C_in] output tensor + """ + if mask is None: + # [*, I, J] + mask = x.new_ones( + x.shape[:-1], + ) + + if not self.starting: + x = x.transpose(-2, -3) + mask = mask.transpose(-1, -2) + + # [*, I, J, C_in] + x = self.layer_norm(x) + + # [*, I, 1, 1, J] + mask_bias = (self.inf * (mask - 1))[..., :, None, None, :] + + # [*, H, I, J] + triangle_bias = permute_final_dims(self.linear(x), (2, 0, 1)) + + # [*, 1, H, I, J] + triangle_bias = triangle_bias.unsqueeze(-4) + + biases = [mask_bias, triangle_bias] + + if chunk_size is not None: + x = self._chunk( + x, + biases, + chunk_size, + use_memory_efficient_kernel=use_memory_efficient_kernel, + use_lma=use_lma, + inplace_safe=inplace_safe, + ) + else: + x = self.mha( + q_x=x, kv_x=x, biases=biases, use_memory_efficient_kernel=use_memory_efficient_kernel, use_lma=use_lma + ) + + if not self.starting: + x = x.transpose(-2, -3) + + return x + + +class EsmFoldTriangleMultiplicativeUpdate(nn.Module): + """ + Implements Algorithms 11 and 12. + """ + + def __init__(self, config, _outgoing=True): + super().__init__() + c_hidden = config.pairwise_state_dim + self._outgoing = _outgoing + + self.linear_a_p = EsmFoldLinear(c_hidden, c_hidden) + self.linear_a_g = EsmFoldLinear(c_hidden, c_hidden, init="gating") + self.linear_b_p = EsmFoldLinear(c_hidden, c_hidden) + self.linear_b_g = EsmFoldLinear(c_hidden, c_hidden, init="gating") + self.linear_g = EsmFoldLinear(c_hidden, c_hidden, init="gating") + self.linear_z = EsmFoldLinear(c_hidden, c_hidden, init="final") + + self.layer_norm_in = LayerNorm(c_hidden) + self.layer_norm_out = LayerNorm(c_hidden) + + self.sigmoid = nn.Sigmoid() + + def _combine_projections( + self, a: torch.Tensor, b: torch.Tensor, _inplace_chunk_size: Optional[int] = None + ) -> torch.Tensor: + if self._outgoing: + a = permute_final_dims(a, (2, 0, 1)) + b = permute_final_dims(b, (2, 1, 0)) + else: + a = permute_final_dims(a, (2, 1, 0)) + b = permute_final_dims(b, (2, 0, 1)) + + if _inplace_chunk_size is not None: + # To be replaced by torch vmap + for i in range(0, a.shape[-3], _inplace_chunk_size): + a_chunk = a[..., i : i + _inplace_chunk_size, :, :] + b_chunk = b[..., i : i + _inplace_chunk_size, :, :] + a[..., i : i + _inplace_chunk_size, :, :] = torch.matmul( + a_chunk, + b_chunk, + ) + + p = a + else: + p = torch.matmul(a, b) + + return permute_final_dims(p, (1, 2, 0)) + + def _inference_forward( + self, + z: torch.Tensor, + mask: Optional[torch.Tensor] = None, + inplace_chunk_size: Optional[int] = None, + with_add: bool = True, + ): + """ + Args: + z: + A [*, N, N, C_z] pair representation + mask: + A [*, N, N] pair mask + inplace_chunk_size: + Size of chunks used in the main computation. Increase to trade memory for speed. + with_add: + If True, z is overwritten with (z + update). Otherwise, it is overwritten with (update). + Returns: + A reference to the overwritten z + + More memory-efficient, inference-only version of the forward function. Uses in-place operations, fusion of the + addition that happens after this module in the Evoformer, a smidge of recomputation, and a cache of overwritten + values to lower peak memory consumption of this module from 5x the size of the input tensor z to 2.5x its size. + Useful for inference on extremely long sequences. + + It works as follows. We will make reference to variables used in the default forward implementation below. + Naively, triangle multiplication attention requires the manifestation of 5 tensors the size of z: 1) z, the + "square" input tensor, 2) a, the first projection of z, 3) b, the second projection of b, 4) g, a z-sized mask, + and 5) a z-sized tensor for intermediate computations. For large N, this is prohibitively expensive; for + N=4000, for example, z is more than 8GB alone. To avoid this problem, we compute b, g, and all intermediate + tensors in small chunks, noting that the chunks required to compute a chunk of the output depend only on the + tensor a and corresponding vertical and horizontal chunks of z. This suggests an algorithm that loops over + pairs of chunks of z: hereafter "columns" and "rows" of z, even though each "column" and "row" in fact contains + inplace_chunk_size contiguous true columns and rows of z. Writing output chunks to a new tensor would bring + total memory consumption down to 3x the size of z. However, more memory can be saved by writing output chunks + directly to z in-place. WLOG, we choose to write output chunks vertically, overwriting the ith "column" of z at + the end of the ith iteration of the main loop. Despite this overwriting, the ith column is always one column + ahead of previously overwritten columns and can be recovered directly from z. After the first iteration, + however, the ith row of z is always at least partially overwritten. For this reason, we introduce the z-cache, + a tensor one-half the size of z. The z-cache initially contains the left half (2nd and 3rd quadrants) of z. For + 0 < i < N/2, the missing left part of the ith row of z is recovered from this cache at the beginning of the ith + iteration. Once i exceeds n/2, the cache is "reoriented" to encompass the 3rd and 4th quadrants of z instead. + Though the 3rd quadrant of the original z is entirely overwritten at this point, it can be recovered from the + z-cache itself. Thereafter, the ith row of z can be recovered in its entirety from the reoriented z-cache. + After the final iteration, z has been completely overwritten and contains the triangular multiplicative update. + If with_add is True, it instead contains the sum of z and the triangular multiplicative update. In either case, + peak memory consumption is just 2.5x the size of z, disregarding memory used for chunks and other small + variables. + """ + if mask is None: + mask = z.new_ones(z.shape[:-1]) + + mask = mask.unsqueeze(-1) + + def compute_projection_helper(pair, mask, a=True): + if a: + linear_g = self.linear_a_g + linear_p = self.linear_a_p + else: + linear_g = self.linear_b_g + linear_p = self.linear_b_p + + pair = self.layer_norm_in(pair) + p = linear_g(pair) + p.sigmoid_() + p *= linear_p(pair) + p *= mask + p = permute_final_dims(p, (2, 0, 1)) + return p + + def compute_projection(pair, mask, a=True, chunked=True): + need_transpose = self._outgoing ^ a + if not chunked: + p = compute_projection_helper(pair, mask, a) + if need_transpose: + p = p.transpose(-1, -2) + else: + # This computation is chunked so as not to exceed our 2.5x + # budget with a large intermediate tensor + linear_g = self.linear_a_g if a else self.linear_b_g + c = linear_g.bias.shape[-1] + out_shape = pair.shape[:-3] + (c,) + pair.shape[-3:-1] + p = pair.new_zeros(out_shape) + for i in range(0, pair.shape[-3], inplace_chunk_size): + pair_chunk = pair[..., i : i + inplace_chunk_size, :, :] + pair_chunk = compute_projection_helper( + pair[..., i : i + inplace_chunk_size, :, :], + mask[..., i : i + inplace_chunk_size, :, :], + a, + ) + if need_transpose: + pair_chunk = pair_chunk.transpose(-1, -2) + p[..., i : i + inplace_chunk_size] = pair_chunk + else: + p[..., i : i + inplace_chunk_size, :] = pair_chunk + + del pair_chunk + + return p + + # We start by fully manifesting a. In addition to the input, this + # brings total memory consumption to 2x z (disregarding size of chunks) + # [*, N, N, c] + a = compute_projection(z, mask, True, chunked=True) + + if inplace_chunk_size is not None: + n = a.shape[-1] + half_n = n // 2 + n % 2 + row_dim = -3 + col_dim = -2 + b_chunk_dim = row_dim if self._outgoing else col_dim + + def empty_slicer(t): + return [slice(None) for _ in t.shape] + + def slice_tensor(t, start, end, dim): + # Slices start:end from the dim dimension of t + s = empty_slicer(t) + s[dim] = slice(start, end) + return t[s] + + def flip_z_cache_(z_cache, z): + # "Reorient" the z_cache (see below), filling it with quadrants + # 3---recovered from the z_cache---and 4---recovered from z--- + # of the input tensor z. + quadrant_3 = slice_tensor(z_cache, half_n, None, row_dim) + z_cache = z_cache.transpose(row_dim, col_dim) + + # If n is odd, we need to shrink the z_cache by one row + z_cache = z_cache[..., : (n // 2), :, :] + + # Move the 3rd quadrant of z into the + first_half_slicer = empty_slicer(z_cache) + first_half_slicer[col_dim] = slice(0, half_n) + z_cache[first_half_slicer] = quadrant_3 + + # Get the fourth quadrant of z + quadrant_4 = slice_tensor(z, half_n, None, row_dim) + quadrant_4 = slice_tensor(quadrant_4, half_n, None, col_dim) + + # Insert said quadrant into the rotated z-cache + quadrant_3_slicer = empty_slicer(z_cache) + quadrant_3_slicer[col_dim] = slice(half_n, None) + + z_cache[quadrant_3_slicer] = quadrant_4 + + return z_cache + + # Initialize the z cache to the left half of z. + z_cache_shape = list(z.shape) + z_cache_shape[col_dim] = half_n + z_cache = z.new_zeros(z_cache_shape) + z_cache_slicer = empty_slicer(z_cache) + z_cache_slicer[col_dim] = slice(0, half_n) + z_cache.copy_(z[z_cache_slicer]) + z_cache_rotated = False + + # We need to reorient the z-cache at the halfway point, and we + # don't want a single chunk to straddle that point. We contract one + # of the chunks in the middle to address that problem. + i_range = list(range(0, half_n, inplace_chunk_size)) + initial_offsets = [i_2 - i_1 for i_1, i_2 in zip(i_range, i_range[1:] + [half_n])] + after_half = list(range(half_n, n, inplace_chunk_size)) + after_half_offsets = [inplace_chunk_size for _ in after_half] + combined_range_with_offsets = zip(i_range + after_half, initial_offsets + after_half_offsets) + for i, offset in combined_range_with_offsets: + if not z_cache_rotated and i >= half_n: + z_cache = flip_z_cache_(z_cache, z) + z_cache_rotated = True + + z_chunk_b = slice_tensor(z, i, i + offset, b_chunk_dim) + mask_chunk = slice_tensor(mask, i, i + offset, b_chunk_dim) + + z_chunk_b = z_chunk_b.clone() + if b_chunk_dim == col_dim: + z_chunk_b = slice_tensor(z, i, i + offset, col_dim) + else: # b_chunk_dim == row_dim + # In this case, the b-dimension (b_chunk_dim) is partially + # overwritten at the end of each iteration. We need to + # restore the missing component from the z-cache. + if not z_cache_rotated: + z_chunk_slicer = empty_slicer(z_chunk_b) + z_chunk_slicer[col_dim] = slice(0, half_n) + z_chunk_b[z_chunk_slicer] = slice_tensor(z_cache, i, i + offset, row_dim) + else: + z_cache_offset = i - half_n + z_chunk_b = slice_tensor(z_cache, z_cache_offset, z_cache_offset + offset, row_dim) + + b_chunk = compute_projection(z_chunk_b, mask_chunk, a=False, chunked=False) + del z_chunk_b + + x_chunk = torch.matmul(a, b_chunk) + x_chunk = permute_final_dims(x_chunk, (1, 2, 0)) + x_chunk = self.layer_norm_out(x_chunk) + x_chunk = self.linear_z(x_chunk) + + # The g dimension (col_dim) is parallel to and ahead of the + # overwrites in z. We can extract the g chunk normally. + z_chunk_g = slice_tensor(z, i, i + offset, col_dim) + g_chunk = self.linear_g(self.layer_norm_in(z_chunk_g)) + g_chunk.sigmoid_() + del z_chunk_g + + x_chunk *= g_chunk + + # Write the columns into z in-place + z_slicer = empty_slicer(z) + z_slicer[col_dim] = slice(i, i + offset) + if with_add: + z[z_slicer] += x_chunk + else: + z[z_slicer] = x_chunk + else: + b = compute_projection(z, mask, False, False) + x = torch.matmul(a, b) + x = self.layer_norm_out(x) + x = self.linear_z(x) + g = self.linear_g(z) + g.sigmoid_() + x *= g + if with_add: + z += x + else: + z = x + + return z + + def forward( + self, + z: torch.Tensor, + mask: Optional[torch.Tensor] = None, + inplace_safe: bool = False, + _add_with_inplace: bool = False, + _inplace_chunk_size: Optional[int] = 256, + ) -> torch.Tensor: + """ + Args: + x: + [*, N_res, N_res, C_z] input tensor + mask: + [*, N_res, N_res] input mask + Returns: + [*, N_res, N_res, C_z] output tensor + """ + if inplace_safe: + x = self._inference_forward( + z, + mask, + inplace_chunk_size=_inplace_chunk_size, + with_add=_add_with_inplace, + ) + return x + + if mask is None: + mask = z.new_ones(z.shape[:-1]) + + mask = mask.unsqueeze(-1) + + z = self.layer_norm_in(z) + a = mask + a = a * self.sigmoid(self.linear_a_g(z)) + a = a * self.linear_a_p(z) + b = mask + b = b * self.sigmoid(self.linear_b_g(z)) + b = b * self.linear_b_p(z) + + if is_fp16_enabled(): + with torch.cuda.amp.autocast(enabled=False): + x = self._combine_projections(a.float(), b.float()) + else: + x = self._combine_projections(a, b) + + del a, b + x = self.layer_norm_out(x) + x = self.linear_z(x) + g = self.sigmoid(self.linear_g(z)) + x = x * g + + return x + + +class EsmFoldPreTrainedModel(EsmPreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + # Subclass `EsMPreTrainedModel` to deal with special init + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, EsmFoldLinear): + with torch.no_grad(): + if module.init_fn is not None: + module.init_fn(module.weight, module.bias) + elif module.init == "default": + trunc_normal_init_(module.weight, scale=1.0) + elif module.init == "relu": + trunc_normal_init_(module.weight, scale=2.0) + elif module.init == "glorot": + nn.init.xavier_uniform_(module.weight, gain=1) + elif module.init == "gating": + module.weight.fill_(0.0) + if module.bias: + module.bias.fill_(1.0) + elif module.init == "normal": + torch.nn.init.kaiming_normal_(module.weight, nonlinearity="linear") + elif module.init == "final": + module.weight.fill_(0.0) + elif isinstance(module, EsmFoldInvariantPointAttention): + ipa_point_weights_init_(module.head_weights) + elif isinstance(module, EsmFoldTriangularSelfAttentionBlock): + torch.nn.init.zeros_(module.tri_mul_in.linear_z.weight) + torch.nn.init.zeros_(module.tri_mul_in.linear_z.bias) + torch.nn.init.zeros_(module.tri_mul_out.linear_z.weight) + torch.nn.init.zeros_(module.tri_mul_out.linear_z.bias) + torch.nn.init.zeros_(module.tri_att_start.mha.linear_o.weight) + torch.nn.init.zeros_(module.tri_att_start.mha.linear_o.bias) + torch.nn.init.zeros_(module.tri_att_end.mha.linear_o.weight) + torch.nn.init.zeros_(module.tri_att_end.mha.linear_o.bias) + + torch.nn.init.zeros_(module.sequence_to_pair.o_proj.weight) + torch.nn.init.zeros_(module.sequence_to_pair.o_proj.bias) + torch.nn.init.zeros_(module.pair_to_sequence.linear.weight) + torch.nn.init.zeros_(module.seq_attention.o_proj.weight) + torch.nn.init.zeros_(module.seq_attention.o_proj.bias) + torch.nn.init.zeros_(module.mlp_seq.mlp[-2].weight) + torch.nn.init.zeros_(module.mlp_seq.mlp[-2].bias) + torch.nn.init.zeros_(module.mlp_pair.mlp[-2].weight) + torch.nn.init.zeros_(module.mlp_pair.mlp[-2].bias) + else: + super()._init_weights(module) + + +class EsmFoldSelfAttention(nn.Module): + def __init__(self, embed_dim, num_heads, head_width, gated=False): + super().__init__() + assert embed_dim == num_heads * head_width + + self.embed_dim = embed_dim + self.num_heads = num_heads + self.head_width = head_width + + self.proj = nn.Linear(embed_dim, embed_dim * 3, bias=False) + self.o_proj = nn.Linear(embed_dim, embed_dim, bias=True) + self.gated = gated + if gated: + self.g_proj = nn.Linear(embed_dim, embed_dim) + torch.nn.init.zeros_(self.g_proj.weight) + torch.nn.init.ones_(self.g_proj.bias) + + self.rescale_factor = self.head_width**-0.5 + + torch.nn.init.zeros_(self.o_proj.bias) + + def forward(self, x, mask=None, bias=None, indices=None): + """ + Basic self attention with optional mask and external pairwise bias. To handle sequences of different lengths, + use mask. + + Inputs: + x: batch of input sequneces (.. x L x C) mask: batch of boolean masks where 1=valid, 0=padding position (.. + x L_k) bias: batch of scalar pairwise attention biases (.. x Lq x Lk x num_heads) + + Outputs: + sequence projection (B x L x embed_dim), attention maps (B x L x L x num_heads) + """ + + t = self.proj(x).view(*x.shape[:2], self.num_heads, -1) + t = t.permute(0, 2, 1, 3) + q, k, v = t.chunk(3, dim=-1) + + q = self.rescale_factor * q + a = torch.einsum("...qc,...kc->...qk", q, k) + + # Add external attention bias. + if bias is not None: + a = a + bias.permute(0, 3, 1, 2) + + # Do not attend to padding tokens. + if mask is not None: + mask = mask[:, None, None] + a = a.masked_fill(mask == False, -np.inf) # noqa: E712 + + a = nn.functional.softmax(a, dim=-1) + + y = torch.einsum("...hqk,...hkc->...qhc", a, v) + y = y.reshape(*y.shape[:2], -1) + + if self.gated: + y = self.g_proj(x).sigmoid() * y + y = self.o_proj(y) + + return y, a.permute(0, 3, 1, 2) + + +class EsmFoldDropout(nn.Module): + """ + Implementation of dropout with the ability to share the dropout mask along a particular dimension. + """ + + def __init__(self, r: float, batch_dim: Union[int, List[int]]): + super().__init__() + + self.r = r + if type(batch_dim) == int: + batch_dim = [batch_dim] + self.batch_dim = batch_dim + self.dropout = nn.Dropout(self.r) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + shape = list(x.shape) + if self.batch_dim is not None: + for bd in self.batch_dim: + shape[bd] = 1 + return x * self.dropout(x.new_ones(shape)) + + +class EsmFoldSequenceToPair(nn.Module): + def __init__(self, sequence_state_dim, inner_dim, pairwise_state_dim): + super().__init__() + + self.layernorm = nn.LayerNorm(sequence_state_dim) + self.proj = nn.Linear(sequence_state_dim, inner_dim * 2, bias=True) + self.o_proj = nn.Linear(2 * inner_dim, pairwise_state_dim, bias=True) + + torch.nn.init.zeros_(self.proj.bias) + torch.nn.init.zeros_(self.o_proj.bias) + + def forward(self, sequence_state): + """ + Inputs: + sequence_state: B x L x sequence_state_dim + + Output: + pairwise_state: B x L x L x pairwise_state_dim + + Intermediate state: + B x L x L x 2*inner_dim + """ + + assert len(sequence_state.shape) == 3 + + s = self.layernorm(sequence_state) + s = self.proj(s) + q, k = s.chunk(2, dim=-1) + + prod = q[:, None, :, :] * k[:, :, None, :] + diff = q[:, None, :, :] - k[:, :, None, :] + + x = torch.cat([prod, diff], dim=-1) + x = self.o_proj(x) + + return x + + +class EsmFoldPairToSequence(nn.Module): + def __init__(self, pairwise_state_dim, num_heads): + super().__init__() + + self.layernorm = nn.LayerNorm(pairwise_state_dim) + self.linear = nn.Linear(pairwise_state_dim, num_heads, bias=False) + + def forward(self, pairwise_state): + """ + Inputs: + pairwise_state: B x L x L x pairwise_state_dim + + Output: + pairwise_bias: B x L x L x num_heads + """ + assert len(pairwise_state.shape) == 4 + z = self.layernorm(pairwise_state) + pairwise_bias = self.linear(z) + return pairwise_bias + + +class EsmFoldResidueMLP(nn.Module): + def __init__(self, embed_dim, inner_dim, dropout=0): + super().__init__() + + self.mlp = nn.Sequential( + nn.LayerNorm(embed_dim), + nn.Linear(embed_dim, inner_dim), + nn.ReLU(), + nn.Linear(inner_dim, embed_dim), + nn.Dropout(dropout), + ) + + def forward(self, x): + return x + self.mlp(x) + + +class EsmFoldTriangularSelfAttentionBlock(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + + sequence_state_dim = config.sequence_state_dim + pairwise_state_dim = config.pairwise_state_dim + sequence_num_heads = sequence_state_dim // config.sequence_head_width + pairwise_num_heads = pairwise_state_dim // config.pairwise_head_width + + self.layernorm_1 = nn.LayerNorm(sequence_state_dim) + + self.sequence_to_pair = EsmFoldSequenceToPair(sequence_state_dim, pairwise_state_dim // 2, pairwise_state_dim) + self.pair_to_sequence = EsmFoldPairToSequence(pairwise_state_dim, sequence_num_heads) + + self.seq_attention = EsmFoldSelfAttention( + sequence_state_dim, sequence_num_heads, config.sequence_head_width, gated=True + ) + self.tri_mul_out = EsmFoldTriangleMultiplicativeUpdate(config, _outgoing=True) + self.tri_mul_in = EsmFoldTriangleMultiplicativeUpdate(config, _outgoing=False) + + self.tri_att_start = EsmFoldTriangleAttention( + pairwise_state_dim, config.pairwise_head_width, pairwise_num_heads, inf=1e9, starting=True + ) + self.tri_att_end = EsmFoldTriangleAttention( + pairwise_state_dim, config.pairwise_head_width, pairwise_num_heads, inf=1e9, starting=False + ) + + self.mlp_seq = EsmFoldResidueMLP(sequence_state_dim, 4 * sequence_state_dim, dropout=config.dropout) + self.mlp_pair = EsmFoldResidueMLP(pairwise_state_dim, 4 * pairwise_state_dim, dropout=config.dropout) + + self.drop = nn.Dropout(config.dropout) + self.row_drop = EsmFoldDropout(config.dropout * 2, 2) + self.col_drop = EsmFoldDropout(config.dropout * 2, 1) + + def forward(self, sequence_state, pairwise_state, mask=None, chunk_size=None, **__kwargs): + """ + Inputs: + sequence_state: B x L x sequence_state_dim pairwise_state: B x L x L x pairwise_state_dim mask: B x L boolean + tensor of valid positions + + Output: + sequence_state: B x L x sequence_state_dim pairwise_state: B x L x L x pairwise_state_dim + """ + if len(sequence_state.shape) != 3: + raise ValueError(f"`sequence_state` should be a 3d-tensor, got {len(sequence_state.shape)} dims.") + if len(pairwise_state.shape) != 4: + raise ValueError(f"`pairwise_state` should be a 4d-tensor, got {len(pairwise_state.shape)} dims.") + if mask is not None and len(mask.shape) != 2: + raise ValueError(f"`mask` should be a 2d-tensor, got {len(mask.shape)} dims.") + + batch_dim, seq_dim, sequence_state_dim = sequence_state.shape + pairwise_state_dim = pairwise_state.shape[3] + + if sequence_state_dim != self.config.sequence_state_dim: + raise ValueError( + "`sequence_state` last dimension should be equal to `self.sequence_state_dim`. Got" + f"{sequence_state_dim} != {self.config.sequence_state_dim}." + ) + if pairwise_state_dim != self.config.pairwise_state_dim: + raise ValueError( + "`pairwise_state` last dimension should be equal to `self.pairwise_state_dim`. Got " + f"{pairwise_state_dim} != {self.config.pairwise_state_dim}." + ) + if batch_dim != pairwise_state.shape[0]: + raise ValueError( + f"`sequence_state` and `pairwise_state` have inconsistent batch size: {batch_dim} != " + f"{pairwise_state.shape[0]}." + ) + if seq_dim != pairwise_state.shape[1] or seq_dim != pairwise_state.shape[2]: + raise ValueError( + f"`sequence_state` and `pairwise_state` have inconsistent sequence length: {seq_dim} != " + f"{pairwise_state.shape[1]} or {pairwise_state.shape[2]}." + ) + + # Update sequence state + bias = self.pair_to_sequence(pairwise_state) + + # Self attention with bias + mlp. + y = self.layernorm_1(sequence_state) + y, _ = self.seq_attention(y, mask=mask, bias=bias) + sequence_state = sequence_state + self.drop(y) + sequence_state = self.mlp_seq(sequence_state) + + # Update pairwise state + pairwise_state = pairwise_state + self.sequence_to_pair(sequence_state) + + # Axial attention with triangular bias. + tri_mask = mask.unsqueeze(2) * mask.unsqueeze(1) if mask is not None else None + pairwise_state = pairwise_state + self.row_drop(self.tri_mul_out(pairwise_state, mask=tri_mask)) + pairwise_state = pairwise_state + self.col_drop(self.tri_mul_in(pairwise_state, mask=tri_mask)) + pairwise_state = pairwise_state + self.row_drop( + self.tri_att_start(pairwise_state, mask=tri_mask, chunk_size=chunk_size) + ) + pairwise_state = pairwise_state + self.col_drop( + self.tri_att_end(pairwise_state, mask=tri_mask, chunk_size=chunk_size) + ) + + # MLP over pairs. + pairwise_state = self.mlp_pair(pairwise_state) + + return sequence_state, pairwise_state + + +class EsmCategoricalMixture: + def __init__(self, param, bins=50, start=0, end=1): + # All tensors are of shape ..., bins. + self.logits = param + bins = torch.linspace(start, end, bins + 1, device=self.logits.device, dtype=self.logits.dtype) + self.v_bins = (bins[:-1] + bins[1:]) / 2 + + def log_prob(self, true): + # Shapes are: + # self.probs: ... x bins + # true : ... + true_index = (true.unsqueeze(-1) - self.v_bins[[None] * true.ndim]).abs().argmin(-1) + nll = self.logits.log_softmax(-1) + return torch.take_along_dim(nll, true_index.unsqueeze(-1), dim=-1).squeeze(-1) + + def mean(self): + return (self.logits.softmax(-1) @ self.v_bins.unsqueeze(1)).squeeze(-1) + + +def categorical_lddt(logits, bins=50): + # Logits are ..., 37, bins. + return EsmCategoricalMixture(logits, bins=bins).mean() + + +def get_axial_mask(mask): + """ + Helper to convert B x L mask of valid positions to axial mask used in row column attentions. + + Input: + mask: B x L tensor of booleans + + Output: + mask: B x L x L tensor of booleans + """ + + if mask is None: + return None + + if len(mask.shape) != 2: + raise ValueError(f"`mask` should be a 2d-tensor, got {len(mask.shape)} dims.") + batch_dim, seq_dim = mask.shape + m = mask.unsqueeze(1).expand(batch_dim, seq_dim, seq_dim) + m = m.reshape(batch_dim * seq_dim, seq_dim) + return m + + +class EsmFoldRelativePosition(nn.Module): + def __init__(self, config): + super().__init__() + self.bins = config.position_bins + + # Note an additional offset is used so that the 0th position + # is reserved for masked pairs. + self.embedding = torch.nn.Embedding(2 * self.bins + 2, config.pairwise_state_dim) + + def forward(self, residue_index, mask=None): + """ + Input: + residue_index: B x L tensor of indices (dytpe=torch.long) mask: B x L tensor of booleans + + Output: + pairwise_state: B x L x L x pairwise_state_dim tensor of embeddings + """ + if residue_index.dtype != torch.long: + raise ValueError(f"`residue_index` has dtype {residue_index.dtype}, it should be `torch.long`.") + if mask is not None and residue_index.shape != mask.shape: + raise ValueError( + f"`residue_index` and `mask` have inconsistent shapes: {residue_index.shape} != {mask.shape}." + ) + + diff = residue_index[:, None, :] - residue_index[:, :, None] + diff = diff.clamp(-self.bins, self.bins) + diff = diff + self.bins + 1 # Add 1 to adjust for padding index. + + if mask is not None: + mask = mask[:, None, :] * mask[:, :, None] + diff[mask == False] = 0 # noqa: E712 + + output = self.embedding(diff) + return output + + +class EsmFoldAngleResnetBlock(nn.Module): + def __init__(self, config): + super().__init__() + + self.linear_1 = EsmFoldLinear(config.resnet_dim, config.resnet_dim, init="relu") + self.linear_2 = EsmFoldLinear(config.resnet_dim, config.resnet_dim, init="final") + + self.relu = nn.ReLU() + + def forward(self, a: torch.Tensor) -> torch.Tensor: + s_initial = a + + a = self.relu(a) + a = self.linear_1(a) + a = self.relu(a) + a = self.linear_2(a) + + return a + s_initial + + +class EsmFoldAngleResnet(nn.Module): + """ + Implements Algorithm 20, lines 11-14 + """ + + def __init__(self, config): + super().__init__() + self.config = config + + self.linear_in = EsmFoldLinear(config.sequence_dim, config.resnet_dim) + self.linear_initial = EsmFoldLinear(config.sequence_dim, config.resnet_dim) + + self.layers = nn.ModuleList() + for _ in range(config.num_resnet_blocks): + layer = EsmFoldAngleResnetBlock(config) + self.layers.append(layer) + + self.linear_out = EsmFoldLinear(config.resnet_dim, config.num_angles * 2) + + self.relu = nn.ReLU() + + def forward(self, s: torch.Tensor, s_initial: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + s: + [*, C_hidden] single embedding + s_initial: + [*, C_hidden] single embedding as of the start of the StructureModule + Returns: + [*, no_angles, 2] predicted angles + """ + # NOTE: The ReLU's applied to the inputs are absent from the supplement + # pseudocode but present in the source. For maximal compatibility with + # the pretrained weights, I'm going with the source. + + # [*, C_hidden] + s_initial = self.relu(s_initial) + s_initial = self.linear_initial(s_initial) + s = self.relu(s) + s = self.linear_in(s) + s = s + s_initial + + for l in self.layers: + s = l(s) + + s = self.relu(s) + + # [*, no_angles * 2] + s = self.linear_out(s) + + # [*, no_angles, 2] + s = s.view(s.shape[:-1] + (-1, 2)) + + unnormalized_s = s + norm_denom = torch.sqrt( + torch.clamp( + torch.sum(s**2, dim=-1, keepdim=True), + min=self.config.epsilon, + ) + ) + s = s / norm_denom + + return unnormalized_s, s + + +class EsmFoldInvariantPointAttention(nn.Module): + """ + Implements Algorithm 22. + """ + + def __init__(self, config): + super().__init__() + self.config = config + + c_s = config.sequence_dim + c_z = config.pairwise_dim + self.hidden_dim = config.ipa_dim + self.num_heads = config.num_heads_ipa + self.num_qk_points = config.num_qk_points + self.num_v_points = config.num_v_points + + # These linear layers differ from their specifications in the + # supplement. There, they lack bias and use Glorot initialization. + # Here as in the official source, they have bias and use the default + # Lecun initialization. + hc = config.ipa_dim * config.num_heads_ipa + self.linear_q = EsmFoldLinear(c_s, hc) + self.linear_kv = EsmFoldLinear(c_s, 2 * hc) + + hpq = config.num_heads_ipa * config.num_qk_points * 3 + self.linear_q_points = EsmFoldLinear(c_s, hpq) + + hpkv = config.num_heads_ipa * (config.num_qk_points + config.num_v_points) * 3 + self.linear_kv_points = EsmFoldLinear(c_s, hpkv) + + self.linear_b = EsmFoldLinear(c_z, config.num_heads_ipa) + + self.head_weights = nn.Parameter(torch.zeros((config.num_heads_ipa))) + + concat_out_dim = config.num_heads_ipa * (c_z + config.ipa_dim + config.num_v_points * 4) + self.linear_out = EsmFoldLinear(concat_out_dim, c_s, init="final") + + self.softmax = nn.Softmax(dim=-1) + self.softplus = nn.Softplus() + + def forward( + self, + s: torch.Tensor, + z: Optional[torch.Tensor], + r: Rigid, + mask: torch.Tensor, + _offload_inference: bool = False, + _z_reference_list: Optional[Sequence[torch.Tensor]] = None, + ) -> torch.Tensor: + """ + Args: + s: + [*, N_res, C_s] single representation + z: + [*, N_res, N_res, C_z] pair representation + r: + [*, N_res] transformation object + mask: + [*, N_res] mask + Returns: + [*, N_res, C_s] single representation update + """ + z = [z] + + ####################################### + # Generate scalar and point activations + ####################################### + # [*, N_res, H * C_hidden] + q = self.linear_q(s) + kv = self.linear_kv(s) + + # [*, N_res, H, C_hidden] + q = q.view(q.shape[:-1] + (self.num_heads, -1)) + + # [*, N_res, H, 2 * C_hidden] + kv = kv.view(kv.shape[:-1] + (self.num_heads, -1)) + + # [*, N_res, H, C_hidden] + k, v = torch.split(kv, self.hidden_dim, dim=-1) + + # [*, N_res, H * P_q * 3] + q_pts = self.linear_q_points(s) + + # This is kind of clunky, but it's how the original does it + # [*, N_res, H * P_q, 3] + q_pts = torch.split(q_pts, q_pts.shape[-1] // 3, dim=-1) + q_pts = torch.stack(q_pts, dim=-1) + q_pts = r[..., None].apply(q_pts) + + # [*, N_res, H, P_q, 3] + q_pts = q_pts.view(q_pts.shape[:-2] + (self.num_heads, self.num_qk_points, 3)) + + # [*, N_res, H * (P_q + P_v) * 3] + kv_pts = self.linear_kv_points(s) + + # [*, N_res, H * (P_q + P_v), 3] + kv_pts = torch.split(kv_pts, kv_pts.shape[-1] // 3, dim=-1) + kv_pts = torch.stack(kv_pts, dim=-1) + kv_pts = r[..., None].apply(kv_pts) + + # [*, N_res, H, (P_q + P_v), 3] + kv_pts = kv_pts.view(kv_pts.shape[:-2] + (self.num_heads, -1, 3)) + + # [*, N_res, H, P_q/P_v, 3] + k_pts, v_pts = torch.split(kv_pts, [self.num_qk_points, self.num_v_points], dim=-2) + + ########################## + # Compute attention scores + ########################## + # [*, N_res, N_res, H] + b = self.linear_b(z[0]) + + if _offload_inference: + assert sys.getrefcount(z[0]) == 2 + z[0] = z[0].cpu() + + # [*, H, N_res, N_res] + if is_fp16_enabled(): + with torch.cuda.amp.autocast(enabled=False): + a = torch.matmul( + permute_final_dims(q.float(), (1, 0, 2)), # [*, H, N_res, C_hidden] + permute_final_dims(k.float(), (1, 2, 0)), # [*, H, C_hidden, N_res] + ) + else: + a = torch.matmul( + permute_final_dims(q, (1, 0, 2)), # [*, H, N_res, C_hidden] + permute_final_dims(k, (1, 2, 0)), # [*, H, C_hidden, N_res] + ) + + a *= math.sqrt(1.0 / (3 * self.hidden_dim)) + a += math.sqrt(1.0 / 3) * permute_final_dims(b, (2, 0, 1)) + + # [*, N_res, N_res, H, P_q, 3] + pt_att = q_pts.unsqueeze(-4) - k_pts.unsqueeze(-5) + pt_att = pt_att**2 + + # [*, N_res, N_res, H, P_q] + pt_att = sum(torch.unbind(pt_att, dim=-1)) + head_weights = self.softplus(self.head_weights).view(*((1,) * len(pt_att.shape[:-2]) + (-1, 1))) + head_weights = head_weights * math.sqrt(1.0 / (3 * (self.num_qk_points * 9.0 / 2))) + pt_att = pt_att * head_weights + + # [*, N_res, N_res, H] + pt_att = torch.sum(pt_att, dim=-1) * (-0.5) + # [*, N_res, N_res] + square_mask = mask.unsqueeze(-1) * mask.unsqueeze(-2) + square_mask = self.config.inf * (square_mask - 1) + + # [*, H, N_res, N_res] + pt_att = permute_final_dims(pt_att, (2, 0, 1)) + + a = a + pt_att + a = a + square_mask.unsqueeze(-3) + a = self.softmax(a) + + ################ + # Compute output + ################ + # [*, N_res, H, C_hidden] + o = torch.matmul(a, v.transpose(-2, -3).to(dtype=a.dtype)).transpose(-2, -3) + + # [*, N_res, H * C_hidden] + o = flatten_final_dims(o, 2) + + # [*, H, 3, N_res, P_v] + o_pt = torch.sum( + (a[..., None, :, :, None] * permute_final_dims(v_pts, (1, 3, 0, 2))[..., None, :, :]), + dim=-2, + ) + + # [*, N_res, H, P_v, 3] + o_pt = permute_final_dims(o_pt, (2, 0, 3, 1)) + o_pt = r[..., None, None].invert_apply(o_pt) + + # [*, N_res, H * P_v] + o_pt_norm = flatten_final_dims(torch.sqrt(torch.sum(o_pt**2, dim=-1) + self.config.epsilon), 2) + + # [*, N_res, H * P_v, 3] + o_pt = o_pt.reshape(*o_pt.shape[:-3], -1, 3) + + if _offload_inference: + z[0] = z[0].to(o_pt.device) + + # [*, N_res, H, C_z] + o_pair = torch.matmul(a.transpose(-2, -3), z[0].to(dtype=a.dtype)) + + # [*, N_res, H * C_z] + o_pair = flatten_final_dims(o_pair, 2) + + # [*, N_res, C_s] + s = self.linear_out( + torch.cat((o, *torch.unbind(o_pt, dim=-1), o_pt_norm, o_pair), dim=-1).to(dtype=z[0].dtype) + ) + + return s + + +class EsmFoldBackboneUpdate(nn.Module): + """ + Implements part of Algorithm 23. + """ + + def __init__(self, config): + super().__init__() + + self.linear = EsmFoldLinear(config.sequence_dim, 6, init="final") + + def forward(self, s: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + [*, N_res, C_s] single representation + Returns: + [*, N_res, 6] update vector + """ + # [*, 6] + update = self.linear(s) + + return update + + +class EsmFoldStructureModuleTransitionLayer(nn.Module): + def __init__(self, config): + super().__init__() + + self.linear_1 = EsmFoldLinear(config.sequence_dim, config.sequence_dim, init="relu") + self.linear_2 = EsmFoldLinear(config.sequence_dim, config.sequence_dim, init="relu") + self.linear_3 = EsmFoldLinear(config.sequence_dim, config.sequence_dim, init="final") + + self.relu = nn.ReLU() + + def forward(self, s): + s_initial = s + s = self.linear_1(s) + s = self.relu(s) + s = self.linear_2(s) + s = self.relu(s) + s = self.linear_3(s) + + s = s + s_initial + + return s + + +class EsmFoldStructureModuleTransition(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + + self.layers = nn.ModuleList() + for _ in range(config.num_transition_layers): + l = EsmFoldStructureModuleTransitionLayer(config) + self.layers.append(l) + + self.dropout = nn.Dropout(config.dropout_rate) + self.layer_norm = LayerNorm(config.sequence_dim) + + def forward(self, s): + for l in self.layers: + s = l(s) + + s = self.dropout(s) + s = self.layer_norm(s) + + return s + + +class EsmFoldStructureModule(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + + # Buffers to be lazily initialized later + # self.default_frames + # self.group_idx + # self.atom_mask + # self.lit_positions + + self.layer_norm_s = LayerNorm(config.sequence_dim) + self.layer_norm_z = LayerNorm(config.pairwise_dim) + + self.linear_in = EsmFoldLinear(config.sequence_dim, config.sequence_dim) + + self.ipa = EsmFoldInvariantPointAttention(config) + + self.ipa_dropout = nn.Dropout(config.dropout_rate) + self.layer_norm_ipa = LayerNorm(config.sequence_dim) + + self.transition = EsmFoldStructureModuleTransition(config) + self.bb_update = EsmFoldBackboneUpdate(config) + self.angle_resnet = EsmFoldAngleResnet(config) + + def forward( + self, + evoformer_output_dict, + aatype, + mask=None, + _offload_inference=False, + ): + """ + Args: + evoformer_output_dict: + Dictionary containing: + "single": + [*, N_res, C_s] single representation + "pair": + [*, N_res, N_res, C_z] pair representation + aatype: + [*, N_res] amino acid indices + mask: + Optional [*, N_res] sequence mask + Returns: + A dictionary of outputs + """ + s = evoformer_output_dict["single"] + + if mask is None: + # [*, N] + mask = s.new_ones(s.shape[:-1]) + + # [*, N, C_s] + s = self.layer_norm_s(s) + + # [*, N, N, C_z] + z = self.layer_norm_z(evoformer_output_dict["pair"]) + + z_reference_list = None + if _offload_inference: + assert sys.getrefcount(evoformer_output_dict["pair"]) == 2 + evoformer_output_dict["pair"] = evoformer_output_dict["pair"].cpu() + z_reference_list = [z] + z = None + + # [*, N, C_s] + s_initial = s + s = self.linear_in(s) + + # [*, N] + rigids = Rigid.identity( + s.shape[:-1], + s.dtype, + s.device, + self.training, + fmt="quat", + ) + outputs = [] + for i in range(self.config.num_blocks): + # [*, N, C_s] + s = s + self.ipa( + s, + z, + rigids, + mask, + _offload_inference=_offload_inference, + _z_reference_list=z_reference_list, + ) + s = self.ipa_dropout(s) + s = self.layer_norm_ipa(s) + s = self.transition(s) + + # [*, N] + rigids = rigids.compose_q_update_vec(self.bb_update(s)) + + # To hew as closely as possible to AlphaFold, we convert our + # quaternion-based transformations to rotation-matrix ones + # here + backb_to_global = Rigid( + Rotation(rot_mats=rigids.get_rots().get_rot_mats(), quats=None), + rigids.get_trans(), + ) + + backb_to_global = backb_to_global.scale_translation(self.config.trans_scale_factor) + + # [*, N, 7, 2] + unnormalized_angles, angles = self.angle_resnet(s, s_initial) + + all_frames_to_global = self.torsion_angles_to_frames(backb_to_global, angles, aatype) + + pred_xyz = self.frames_and_literature_positions_to_atom14_pos(all_frames_to_global, aatype) + + scaled_rigids = rigids.scale_translation(self.config.trans_scale_factor) + + preds = { + "frames": scaled_rigids.to_tensor_7(), + "sidechain_frames": all_frames_to_global.to_tensor_4x4(), + "unnormalized_angles": unnormalized_angles, + "angles": angles, + "positions": pred_xyz, + "states": s, + } + + outputs.append(preds) + + rigids = rigids.stop_rot_gradient() + + del z, z_reference_list + + if _offload_inference: + evoformer_output_dict["pair"] = evoformer_output_dict["pair"].to(s.device) + + outputs = dict_multimap(torch.stack, outputs) + outputs["single"] = s + + return outputs + + def _init_residue_constants(self, float_dtype, device): + if not hasattr(self, "default_frames"): + self.register_buffer( + "default_frames", + torch.tensor( + residue_constants.restype_rigid_group_default_frame, + dtype=float_dtype, + device=device, + requires_grad=False, + ), + persistent=False, + ) + if not hasattr(self, "group_idx"): + self.register_buffer( + "group_idx", + torch.tensor( + residue_constants.restype_atom14_to_rigid_group, + device=device, + requires_grad=False, + ), + persistent=False, + ) + if not hasattr(self, "atom_mask"): + self.register_buffer( + "atom_mask", + torch.tensor( + residue_constants.restype_atom14_mask, + dtype=float_dtype, + device=device, + requires_grad=False, + ), + persistent=False, + ) + if not hasattr(self, "lit_positions"): + self.register_buffer( + "lit_positions", + torch.tensor( + residue_constants.restype_atom14_rigid_group_positions, + dtype=float_dtype, + device=device, + requires_grad=False, + ), + persistent=False, + ) + + def torsion_angles_to_frames(self, r, alpha, f): + # Lazily initialize the residue constants on the correct device + self._init_residue_constants(alpha.dtype, alpha.device) + # Separated purely to make testing less annoying + return torsion_angles_to_frames(r, alpha, f, self.default_frames) + + def frames_and_literature_positions_to_atom14_pos(self, r, f): # [*, N, 8] # [*, N] + # Lazily initialize the residue constants on the correct device + self._init_residue_constants(r.get_rots().dtype, r.get_rots().device) + return frames_and_literature_positions_to_atom14_pos( + r, + f, + self.default_frames, + self.group_idx, + self.atom_mask, + self.lit_positions, + ) + + +class EsmFoldingTrunk(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + + c_s = config.sequence_state_dim + c_z = config.pairwise_state_dim + + self.pairwise_positional_embedding = EsmFoldRelativePosition(config) + + self.blocks = nn.ModuleList([EsmFoldTriangularSelfAttentionBlock(config) for _ in range(config.num_blocks)]) + + self.recycle_bins = 15 + self.recycle_s_norm = nn.LayerNorm(c_s) + self.recycle_z_norm = nn.LayerNorm(c_z) + self.recycle_disto = nn.Embedding(self.recycle_bins, c_z) + self.recycle_disto.weight[0].detach().zero_() + + self.structure_module = EsmFoldStructureModule(config.structure_module) + self.trunk2sm_s = nn.Linear(c_s, config.structure_module.sequence_dim) + self.trunk2sm_z = nn.Linear(c_z, config.structure_module.pairwise_dim) + + self.chunk_size = config.chunk_size + + def set_chunk_size(self, chunk_size): + # This parameter means the axial attention will be computed + # in a chunked manner. This should make the memory used more or less O(L) instead of O(L^2). + # It's equivalent to running a for loop over chunks of the dimension we're iterative over, + # where the chunk_size is the size of the chunks, so 128 would mean to parse 128-lengthed chunks. + self.chunk_size = chunk_size + + def forward(self, seq_feats, pair_feats, true_aa, residx, mask, no_recycles): + """ + Inputs: + seq_feats: B x L x C tensor of sequence features pair_feats: B x L x L x C tensor of pair features residx: B + x L long tensor giving the position in the sequence mask: B x L boolean tensor indicating valid residues + + Output: + predicted_structure: B x L x (num_atoms_per_residue * 3) tensor wrapped in a Coordinates object + """ + + device = seq_feats.device + s_s_0 = seq_feats + s_z_0 = pair_feats + + if no_recycles is None: + no_recycles = self.config.max_recycles + else: + if no_recycles < 0: + raise ValueError("Number of recycles must not be negative.") + no_recycles += 1 # First 'recycle' is just the standard forward pass through the model. + + def trunk_iter(s, z, residx, mask): + z = z + self.pairwise_positional_embedding(residx, mask=mask) + + for block in self.blocks: + s, z = block(s, z, mask=mask, residue_index=residx, chunk_size=self.chunk_size) + return s, z + + s_s = s_s_0 + s_z = s_z_0 + recycle_s = torch.zeros_like(s_s) + recycle_z = torch.zeros_like(s_z) + recycle_bins = torch.zeros(*s_z.shape[:-1], device=device, dtype=torch.int64) + + for recycle_idx in range(no_recycles): + with ContextManagers([] if recycle_idx == no_recycles - 1 else [torch.no_grad()]): + # === Recycling === + recycle_s = self.recycle_s_norm(recycle_s.detach()) + recycle_z = self.recycle_z_norm(recycle_z.detach()) + recycle_z += self.recycle_disto(recycle_bins.detach()) + + s_s, s_z = trunk_iter(s_s_0 + recycle_s, s_z_0 + recycle_z, residx, mask) + + # === Structure module === + structure = self.structure_module( + {"single": self.trunk2sm_s(s_s), "pair": self.trunk2sm_z(s_z)}, + true_aa, + mask.float(), + ) + + recycle_s = s_s + recycle_z = s_z + # Distogram needs the N, CA, C coordinates, and bin constants same as alphafold. + recycle_bins = EsmFoldingTrunk.distogram( + structure["positions"][-1][:, :, :3], + 3.375, + 21.375, + self.recycle_bins, + ) + + structure["s_s"] = s_s + structure["s_z"] = s_z + + return structure + + @staticmethod + def distogram(coords, min_bin, max_bin, num_bins): + # Coords are [... L x 3 x 3], where it's [N, CA, C] x 3 coordinates. + boundaries = torch.linspace( + min_bin, + max_bin, + num_bins - 1, + device=coords.device, + ) + boundaries = boundaries**2 + N, CA, C = [x.squeeze(-2) for x in coords.chunk(3, dim=-2)] + # Infer CB coordinates. + b = CA - N + c = C - CA + a = b.cross(c, dim=-1) + CB = -0.58273431 * a + 0.56802827 * b - 0.54067466 * c + CA + dists = (CB[..., None, :, :] - CB[..., :, None, :]).pow(2).sum(dim=-1, keepdims=True) + bins = torch.sum(dists > boundaries, dim=-1) # [..., L, L] + return bins + + +# TODO Add information to the docstring about any methods that convert to PDB format, or otherwise prepare +# the outputs for downstream use. + + +@add_start_docstrings( + """ + ESMForProteinFolding is the HuggingFace port of the original ESMFold model. It consists of an ESM-2 "stem" followed + by a protein folding "head", although unlike most other output heads, this "head" is similar in size and runtime to + the rest of the model combined! It outputs a dictionary containing predicted structural information about the input + protein(s). + """, + ESM_START_DOCSTRING, +) +class EsmForProteinFolding(EsmPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.config = config + + self.distogram_bins = 64 + + self.esm = EsmModel(config, add_pooling_layer=False) + + self.esm.requires_grad_(False) + if self.config.esmfold_config.fp16_esm: + self.esm.half() + + self.esm_feats = self.config.hidden_size + self.esm_attns = self.config.num_hidden_layers * self.config.num_attention_heads + self.esm_layers = self.config.num_hidden_layers + self.register_buffer("af2_to_esm", self._af2_to_esm_from_vocab_list(config.vocab_list)) + self.esm_s_combine = nn.Parameter(torch.zeros(self.esm_layers + 1)) + + trunk_config = self.config.esmfold_config.trunk + c_s = trunk_config.sequence_state_dim + c_z = trunk_config.pairwise_state_dim + self.esm_s_mlp = nn.Sequential( + LayerNorm(self.esm_feats), + nn.Linear(self.esm_feats, c_s), + nn.ReLU(), + nn.Linear(c_s, c_s), + ) + + # 0 is padding, N is unknown residues, N + 1 is mask. + self.n_tokens_embed = residue_constants.restype_num + 3 + self.pad_idx = 0 + self.unk_idx = self.n_tokens_embed - 2 + self.mask_idx = self.n_tokens_embed - 1 + self.esm_dict_cls_idx = self.config.vocab_list.index("") + self.esm_dict_mask_idx = self.config.vocab_list.index("") + self.esm_dict_eos_idx = self.config.vocab_list.index("") + self.esm_dict_padding_idx = self.config.vocab_list.index("") + if self.config.esmfold_config.embed_aa: + self.embedding = nn.Embedding(self.n_tokens_embed, c_s, padding_idx=0) + + self.trunk = EsmFoldingTrunk(trunk_config) + + self.distogram_head = nn.Linear(c_z, self.distogram_bins) + self.ptm_head = nn.Linear(c_z, self.distogram_bins) + self.lm_head = nn.Linear(c_s, self.n_tokens_embed) + self.lddt_bins = 50 + structure_module_config = trunk_config.structure_module + self.lddt_head = nn.Sequential( + nn.LayerNorm(structure_module_config.sequence_dim), + nn.Linear(structure_module_config.sequence_dim, self.config.esmfold_config.lddt_head_hid_dim), + nn.Linear(self.config.esmfold_config.lddt_head_hid_dim, self.config.esmfold_config.lddt_head_hid_dim), + nn.Linear(self.config.esmfold_config.lddt_head_hid_dim, 37 * self.lddt_bins), + ) + + @staticmethod + def _af2_to_esm_from_vocab_list(vocab_list: List[str]) -> torch.Tensor: + # Remember that t is shifted from residue_constants by 1 (0 is padding). + esm_reorder = [vocab_list.index("")] + [vocab_list.index(v) for v in residue_constants.restypes_with_x] + return torch.tensor(esm_reorder) + + @add_start_docstrings_to_model_forward(ESMFOLD_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @replace_return_docstrings(output_type=EsmForProteinFoldingOutput, config_class=EsmConfig) + def forward( + self, + input_ids: torch.Tensor, + attention_mask: torch.Tensor = None, + position_ids: Optional[torch.Tensor] = None, + masking_pattern: Optional[torch.Tensor] = None, + num_recycles: Optional[int] = None, + ): + r""" + Returns: + + Example: + + TODO Matt + """ + cfg = self.config.esmfold_config + + aa = input_ids # B x L + B = aa.shape[0] + L = aa.shape[1] + device = input_ids.device + if attention_mask is None: + attention_mask = torch.ones_like(aa, device=device) + if position_ids is None: + position_ids = torch.arange(L, device=device).expand_as(input_ids) + + # === ESM === + esmaa = self.af2_idx_to_esm_idx(aa, attention_mask) + + if masking_pattern is not None: + masked_aa, esmaa, mlm_targets = self.bert_mask(aa, esmaa, attention_mask, masking_pattern) + else: + masked_aa = aa + mlm_targets = None + + # We get sequence and pair representations from whatever version of ESM / + # configuration we are using. The sequence representation esm_s is always + # present. The pair embedding esm_z may be present depending on the + # configuration of the model. If esm_z is not used by the model then it + # is returned as None here. + esm_s = self.compute_language_model_representations(esmaa) + + # Convert esm_s and esm_z, if present, to the precision used by the trunk and + # the structure module. These tensors may be a lower precision if, for example, + # we're running the language model in fp16 precision. + esm_s = esm_s.to(self.esm_s_combine.dtype) + + if cfg.esm_ablate_sequence: + esm_s = esm_s * 0 + + esm_s = esm_s.detach() + + # === preprocessing === + esm_s = (self.esm_s_combine.softmax(0).unsqueeze(0) @ esm_s).squeeze(2) + s_s_0 = self.esm_s_mlp(esm_s) + + s_z_0 = s_s_0.new_zeros(B, L, L, cfg.trunk.pairwise_state_dim) + + if self.config.esmfold_config.embed_aa: + s_s_0 += self.embedding(masked_aa) + + structure: dict = self.trunk(s_s_0, s_z_0, aa, position_ids, attention_mask, no_recycles=num_recycles) + # Documenting what we expect: + structure = { + k: v + for k, v in structure.items() + if k + in [ + "s_z", + "s_s", + "frames", + "sidechain_frames", + "unnormalized_angles", + "angles", + "positions", + "states", + ] + } + + # Add BERT mask for the loss to use, if available. + if mlm_targets: + structure["mlm_targets"] = mlm_targets + + disto_logits = self.distogram_head(structure["s_z"]) + disto_logits = (disto_logits + disto_logits.transpose(1, 2)) / 2 + structure["distogram_logits"] = disto_logits + + lm_logits = self.lm_head(structure["s_s"]) + structure["lm_logits"] = lm_logits + + structure["aatype"] = aa + make_atom14_masks(structure) + # Of course, this doesn't respect the true mask because it doesn't know about it... + # We're not going to properly mask change of index tensors: + # "residx_atom14_to_atom37", + # "residx_atom37_to_atom14", + for k in [ + "atom14_atom_exists", + "atom37_atom_exists", + ]: + structure[k] *= attention_mask.unsqueeze(-1) + structure["residue_index"] = position_ids + + lddt_head = self.lddt_head(structure["states"]).reshape(structure["states"].shape[0], B, L, -1, self.lddt_bins) + structure["lddt_head"] = lddt_head + plddt = categorical_lddt(lddt_head[-1], bins=self.lddt_bins) + structure["plddt"] = plddt + + ptm_logits = self.ptm_head(structure["s_z"]) + structure["ptm_logits"] = ptm_logits + structure["ptm"] = compute_tm(ptm_logits, max_bin=31, no_bins=self.distogram_bins) + structure.update(compute_predicted_aligned_error(ptm_logits, max_bin=31, no_bins=self.distogram_bins)) + + return EsmForProteinFoldingOutput(**structure) + + def af2_idx_to_esm_idx(self, aa, mask): + aa = (aa + 1).masked_fill(mask != 1, 0) + return self.af2_to_esm[aa] + + def compute_language_model_representations(self, esmaa: torch.Tensor) -> torch.Tensor: + device = next(self.parameters()).device + B, L = esmaa.shape # B = batch size, L = sequence length. + + if self.config.esmfold_config.bypass_lm: + esm_s = torch.zeros(B, L, self.esm_s_combine.size[0], -1, self.esm_feats, device=device) + return esm_s + + bosi, eosi = self.esm_dict_cls_idx, self.esm_dict_eos_idx + bos = esmaa.new_full((B, 1), bosi) + eos = esmaa.new_full((B, 1), self.esm_dict_padding_idx) + esmaa = torch.cat([bos, esmaa, eos], dim=1) + # Use the first padding index as eos during inference. + esmaa[range(B), (esmaa != 1).sum(1)] = eosi + + # _, esm_z, esm_s = self.esm(esmaa, return_pairs=self.config.esmfold_config.use_esm_attn_map) + # Because we do not support use_esm_attn_map in the HF port as it is not used in any public models, + # esm_z is always None + esm_hidden_states = self.esm(esmaa, attention_mask=esmaa != 1, output_hidden_states=True)["hidden_states"] + esm_s = torch.stack(esm_hidden_states, dim=2) + + esm_s = esm_s[:, 1:-1] # B, L, nLayers, C + + return esm_s + + def bert_mask(self, aa, esmaa, mask, pattern): + new_aa = aa.clone() + target = aa.clone() + new_esmaa = esmaa.clone() + new_aa[pattern == 1] = self.mask_idx + target[pattern != 1] = 0 + new_esmaa[pattern == 1] = self.esm_dict_mask_idx + return new_aa, new_esmaa, target + + @torch.no_grad() + def infer( + self, + seqs: Union[str, List[str]], + residx=None, + with_mask: Optional[torch.Tensor] = None, + ): + if type(seqs) is str: + lst = [seqs] + else: + lst = seqs + # Returns the raw outputs of the model given an input sequence. + device = next(self.parameters()).device + aatype = collate_dense_tensors( + [ + torch.from_numpy( + residue_constants.sequence_to_onehot( + sequence=seq, + mapping=residue_constants.restype_order_with_x, + map_unknown_to_x=True, + ) + ) + .to(device) + .argmax(dim=1) + for seq in lst + ] + ) # B=1 x L + mask = collate_dense_tensors([aatype.new_ones(len(seq)) for seq in lst]) + residx = ( + torch.arange(aatype.shape[1], device=device).expand(len(lst), -1) if residx is None else residx.to(device) + ) + if residx.ndim == 1: + residx = residx.unsqueeze(0) + return self.forward( + aatype, + mask, + mask_aa=with_mask is not None, + masking_pattern=with_mask, + residx=residx, + ) + + @staticmethod + def output_to_pdb(output: Dict) -> List[str]: + """Returns the pbd (file) string from the model given the model output.""" + output = {k: v.to("cpu").numpy() for k, v in output.items()} + pdbs = [] + final_atom_positions = atom14_to_atom37(output["positions"][-1], output) + final_atom_mask = output["atom37_atom_exists"] + for i in range(output["aatype"].shape[0]): + aa = output["aatype"][i] + pred_pos = final_atom_positions[i] + mask = final_atom_mask[i] + resid = output["residue_index"][i] + 1 + pred = OFProtein( + aatype=aa, + atom_positions=pred_pos, + atom_mask=mask, + residue_index=resid, + b_factors=output["plddt"][i], + ) + pdbs.append(to_pdb(pred)) + return pdbs + + def infer_pdb(self, seqs, *args, **kwargs) -> str: + """Returns the pdb (file) string from the model given an input sequence.""" + assert type(seqs) is str + output = self.infer(seqs, *args, **kwargs) + return self.output_to_pdb(output)[0] + + def infer_pdbs(self, seqs: List[str], *args, **kwargs) -> List[str]: + """Returns the pdb (file) string from the model given an input sequence.""" + output = self.infer(seqs, *args, **kwargs) + return self.output_to_pdb(output) diff --git a/src/transformers/models/esm/openfold_utils/__init__.py b/src/transformers/models/esm/openfold_utils/__init__.py new file mode 100644 index 00000000000000..5273860260c19e --- /dev/null +++ b/src/transformers/models/esm/openfold_utils/__init__.py @@ -0,0 +1,8 @@ +# flake8: noqa +from .chunk_utils import chunk_layer +from .data_transforms import make_atom14_masks +from .feats import atom14_to_atom37, frames_and_literature_positions_to_atom14_pos, torsion_angles_to_frames +from .loss import compute_predicted_aligned_error, compute_tm +from .protein import Protein as OFProtein +from .protein import to_pdb +from .rigid_utils import Rigid, Rotation diff --git a/src/transformers/models/esm/openfold_utils/chunk_utils.py b/src/transformers/models/esm/openfold_utils/chunk_utils.py new file mode 100644 index 00000000000000..11e5fff929b28e --- /dev/null +++ b/src/transformers/models/esm/openfold_utils/chunk_utils.py @@ -0,0 +1,398 @@ +# Copyright 2021 AlQuraishi Laboratory +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +import math +from functools import partial +from typing import Any, Callable, Dict, Optional, Sequence, Tuple + +import torch + +from .tensor_utils import tensor_tree_map, tree_map + + +def _fetch_dims(tree): + shapes = [] + tree_type = type(tree) + if tree_type is dict: + for v in tree.values(): + shapes.extend(_fetch_dims(v)) + elif tree_type is list or tree_type is tuple: + for t in tree: + shapes.extend(_fetch_dims(t)) + elif tree_type is torch.Tensor: + shapes.append(tree.shape) + else: + raise ValueError("Not supported") + + return shapes + + +@torch.jit.ignore +def _flat_idx_to_idx( + flat_idx: int, + dims: Tuple[int], +) -> Tuple[int]: + idx = [] + for d in reversed(dims): + idx.append(flat_idx % d) + flat_idx = flat_idx // d + + return tuple(reversed(idx)) + + +@torch.jit.ignore +def _get_minimal_slice_set( + start: Sequence[int], + end: Sequence[int], + dims: int, + start_edges: Optional[Sequence[bool]] = None, + end_edges: Optional[Sequence[bool]] = None, +) -> Sequence[Tuple[int]]: + """ + Produces an ordered sequence of tensor slices that, when used in sequence on a tensor with shape dims, yields + tensors that contain every leaf in the contiguous range [start, end]. Care is taken to yield a short sequence of + slices, and perhaps even the shortest possible (I'm pretty sure it's the latter). + + end is INCLUSIVE. + """ + # start_edges and end_edges both indicate whether, starting from any given + # dimension, the start/end index is at the top/bottom edge of the + # corresponding tensor, modeled as a tree + def reduce_edge_list(l): + tally = 1 + for i in range(len(l)): + reversed_idx = -1 * (i + 1) + l[reversed_idx] *= tally + tally = l[reversed_idx] + + if start_edges is None: + start_edges = [s == 0 for s in start] + reduce_edge_list(start_edges) + if end_edges is None: + end_edges = [e == (d - 1) for e, d in zip(end, dims)] + reduce_edge_list(end_edges) + + # Base cases. Either start/end are empty and we're done, or the final, + # one-dimensional tensor can be simply sliced + if len(start) == 0: + return [tuple()] + elif len(start) == 1: + return [(slice(start[0], end[0] + 1),)] + + slices = [] + path = [] + + # Dimensions common to start and end can be selected directly + for s, e in zip(start, end): + if s == e: + path.append(slice(s, s + 1)) + else: + break + + path = tuple(path) + divergence_idx = len(path) + + # start == end, and we're done + if divergence_idx == len(dims): + return [tuple(path)] + + def upper(): + sdi = start[divergence_idx] + return [ + path + (slice(sdi, sdi + 1),) + s + for s in _get_minimal_slice_set( + start[divergence_idx + 1 :], + [d - 1 for d in dims[divergence_idx + 1 :]], + dims[divergence_idx + 1 :], + start_edges=start_edges[divergence_idx + 1 :], + end_edges=[1 for _ in end_edges[divergence_idx + 1 :]], + ) + ] + + def lower(): + edi = end[divergence_idx] + return [ + path + (slice(edi, edi + 1),) + s + for s in _get_minimal_slice_set( + [0 for _ in start[divergence_idx + 1 :]], + end[divergence_idx + 1 :], + dims[divergence_idx + 1 :], + start_edges=[1 for _ in start_edges[divergence_idx + 1 :]], + end_edges=end_edges[divergence_idx + 1 :], + ) + ] + + # If both start and end are at the edges of the subtree rooted at + # divergence_idx, we can just select the whole subtree at once + if start_edges[divergence_idx] and end_edges[divergence_idx]: + slices.append(path + (slice(start[divergence_idx], end[divergence_idx] + 1),)) + # If just start is at the edge, we can grab almost all of the subtree, + # treating only the ragged bottom edge as an edge case + elif start_edges[divergence_idx]: + slices.append(path + (slice(start[divergence_idx], end[divergence_idx]),)) + slices.extend(lower()) + # Analogous to the previous case, but the top is ragged this time + elif end_edges[divergence_idx]: + slices.extend(upper()) + slices.append(path + (slice(start[divergence_idx] + 1, end[divergence_idx] + 1),)) + # If both sides of the range are ragged, we need to handle both sides + # separately. If there's contiguous meat in between them, we can index it + # in one big chunk + else: + slices.extend(upper()) + middle_ground = end[divergence_idx] - start[divergence_idx] + if middle_ground > 1: + slices.append(path + (slice(start[divergence_idx] + 1, end[divergence_idx]),)) + slices.extend(lower()) + + return [tuple(s) for s in slices] + + +@torch.jit.ignore +def _chunk_slice( + t: torch.Tensor, + flat_start: int, + flat_end: int, + no_batch_dims: int, +) -> torch.Tensor: + """ + Equivalent to + + t.reshape((-1,) + t.shape[no_batch_dims:])[flat_start:flat_end] + + but without the need for the initial reshape call, which can be memory-intensive in certain situations. The only + reshape operations in this function are performed on sub-tensors that scale with (flat_end - flat_start), the chunk + size. + """ + + batch_dims = t.shape[:no_batch_dims] + start_idx = list(_flat_idx_to_idx(flat_start, batch_dims)) + # _get_minimal_slice_set is inclusive + end_idx = list(_flat_idx_to_idx(flat_end - 1, batch_dims)) + + # Get an ordered list of slices to perform + slices = _get_minimal_slice_set( + start_idx, + end_idx, + batch_dims, + ) + + sliced_tensors = [t[s] for s in slices] + + return torch.cat([s.view((-1,) + t.shape[no_batch_dims:]) for s in sliced_tensors]) + + +def chunk_layer( + layer: Callable, + inputs: Dict[str, Any], + chunk_size: int, + no_batch_dims: int, + low_mem: bool = False, + _out: Any = None, + _add_into_out: bool = False, +) -> Any: + """ + Implements the "chunking" procedure described in section 1.11.8. + + Layer outputs and inputs are assumed to be simple "pytrees," consisting only of (arbitrarily nested) lists, tuples, + and dicts with torch.Tensor leaves. + + Args: + layer: + The layer to be applied chunk-wise + inputs: + A (non-nested) dictionary of keyworded inputs. All leaves must be tensors and must share the same batch + dimensions. + chunk_size: + The number of sub-batches per chunk. If multiple batch dimensions are specified, a "sub-batch" is defined + as a single indexing of all batch dimensions simultaneously (s.t. the number of sub-batches is the product + of the batch dimensions). + no_batch_dims: + How many of the initial dimensions of each input tensor can be considered batch dimensions. + low_mem: + Avoids flattening potentially large input tensors. Unnecessary in most cases, and is ever so slightly + slower than the default setting. + Returns: + The reassembled output of the layer on the inputs. + """ + if not (len(inputs) > 0): + raise ValueError("Must provide at least one input") + + initial_dims = [shape[:no_batch_dims] for shape in _fetch_dims(inputs)] + orig_batch_dims = tuple([max(s) for s in zip(*initial_dims)]) + + def _prep_inputs(t): + if not low_mem: + if not sum(t.shape[:no_batch_dims]) == no_batch_dims: + t = t.expand(orig_batch_dims + t.shape[no_batch_dims:]) + t = t.reshape(-1, *t.shape[no_batch_dims:]) + else: + t = t.expand(orig_batch_dims + t.shape[no_batch_dims:]) + return t + + prepped_inputs = tensor_tree_map(_prep_inputs, inputs) + prepped_outputs = None + if _out is not None: + prepped_outputs = tensor_tree_map(lambda t: t.view([-1] + list(t.shape[no_batch_dims:])), _out) + + flat_batch_dim = 1 + for d in orig_batch_dims: + flat_batch_dim *= d + + no_chunks = flat_batch_dim // chunk_size + (flat_batch_dim % chunk_size != 0) + + def _select_chunk(t): + return t[i : i + chunk_size] if t.shape[0] != 1 else t + + i = 0 + out = prepped_outputs + for _ in range(no_chunks): + # Chunk the input + if not low_mem: + select_chunk = _select_chunk + else: + select_chunk = partial( + _chunk_slice, + flat_start=i, + flat_end=min(flat_batch_dim, i + chunk_size), + no_batch_dims=len(orig_batch_dims), + ) + + chunks = tensor_tree_map(select_chunk, prepped_inputs) + + # Run the layer on the chunk + output_chunk = layer(**chunks) + + # Allocate space for the output + if out is None: + out = tensor_tree_map(lambda t: t.new_zeros((flat_batch_dim,) + t.shape[1:]), output_chunk) + + # Put the chunk in its pre-allocated space + out_type = type(output_chunk) + if out_type is dict: + + def assign(d1, d2): + for k, v in d1.items(): + if type(v) is dict: + assign(v, d2[k]) + else: + if _add_into_out: + v[i : i + chunk_size] += d2[k] + else: + v[i : i + chunk_size] = d2[k] + + assign(out, output_chunk) + elif out_type is tuple: + for x1, x2 in zip(out, output_chunk): + if _add_into_out: + x1[i : i + chunk_size] += x2 + else: + x1[i : i + chunk_size] = x2 + elif out_type is torch.Tensor: + if _add_into_out: + out[i : i + chunk_size] += output_chunk + else: + out[i : i + chunk_size] = output_chunk + else: + raise ValueError("Not supported") + + i += chunk_size + + out = tensor_tree_map(lambda t: t.view(orig_batch_dims + t.shape[1:]), out) + + return out + + +class ChunkSizeTuner: + def __init__( + self, + # Heuristically, runtimes for most of the modules in the network + # plateau earlier than this on all GPUs I've run the model on. + max_chunk_size=512, + ): + self.max_chunk_size = max_chunk_size + self.cached_chunk_size = None + self.cached_arg_data = None + + def _determine_favorable_chunk_size(self, fn, args, min_chunk_size): + logging.info("Tuning chunk size...") + + if min_chunk_size >= self.max_chunk_size: + return min_chunk_size + + candidates = [2**l for l in range(int(math.log(self.max_chunk_size, 2)) + 1)] + candidates = [c for c in candidates if c > min_chunk_size] + candidates = [min_chunk_size] + candidates + candidates[-1] += 4 + + def test_chunk_size(chunk_size): + try: + with torch.no_grad(): + fn(*args, chunk_size=chunk_size) + return True + except RuntimeError: + return False + + min_viable_chunk_size_index = 0 + i = len(candidates) - 1 + while i > min_viable_chunk_size_index: + viable = test_chunk_size(candidates[i]) + if not viable: + i = (min_viable_chunk_size_index + i) // 2 + else: + min_viable_chunk_size_index = i + i = (i + len(candidates) - 1) // 2 + + return candidates[min_viable_chunk_size_index] + + def _compare_arg_caches(self, ac1, ac2): + consistent = True + for a1, a2 in zip(ac1, ac2): + assert type(ac1) == type(ac2) + if type(ac1) is list or type(ac1) is tuple: + consistent &= self._compare_arg_caches(a1, a2) + elif type(ac1) is dict: + a1_items = [v for _, v in sorted(a1.items(), key=lambda x: x[0])] + a2_items = [v for _, v in sorted(a2.items(), key=lambda x: x[0])] + consistent &= self._compare_arg_caches(a1_items, a2_items) + else: + consistent &= a1 == a2 + + return consistent + + def tune_chunk_size( + self, + representative_fn: Callable, + args: Tuple[Any], + min_chunk_size: int, + ) -> int: + consistent = True + arg_data = tree_map(lambda a: a.shape if type(a) is torch.Tensor else a, args, object) + if self.cached_arg_data is not None: + # If args have changed shape/value, we need to re-tune + assert len(self.cached_arg_data) == len(arg_data) + consistent = self._compare_arg_caches(self.cached_arg_data, arg_data) + else: + # Otherwise, we can reuse the precomputed value + consistent = False + + if not consistent: + self.cached_chunk_size = self._determine_favorable_chunk_size( + representative_fn, + args, + min_chunk_size, + ) + self.cached_arg_data = arg_data + + return self.cached_chunk_size diff --git a/src/transformers/models/esm/openfold_utils/data_transforms.py b/src/transformers/models/esm/openfold_utils/data_transforms.py new file mode 100644 index 00000000000000..e9e5d693169d1c --- /dev/null +++ b/src/transformers/models/esm/openfold_utils/data_transforms.py @@ -0,0 +1,92 @@ +# Copyright 2021 AlQuraishi Laboratory +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import torch + +from . import residue_constants as rc +from .tensor_utils import tensor_tree_map, tree_map + + +def make_atom14_masks(protein): + """Construct denser atom positions (14 dimensions instead of 37).""" + restype_atom14_to_atom37 = [] + restype_atom37_to_atom14 = [] + restype_atom14_mask = [] + + for rt in rc.restypes: + atom_names = rc.restype_name_to_atom14_names[rc.restype_1to3[rt]] + restype_atom14_to_atom37.append([(rc.atom_order[name] if name else 0) for name in atom_names]) + atom_name_to_idx14 = {name: i for i, name in enumerate(atom_names)} + restype_atom37_to_atom14.append( + [(atom_name_to_idx14[name] if name in atom_name_to_idx14 else 0) for name in rc.atom_types] + ) + + restype_atom14_mask.append([(1.0 if name else 0.0) for name in atom_names]) + + # Add dummy mapping for restype 'UNK' + restype_atom14_to_atom37.append([0] * 14) + restype_atom37_to_atom14.append([0] * 37) + restype_atom14_mask.append([0.0] * 14) + + restype_atom14_to_atom37 = torch.tensor( + restype_atom14_to_atom37, + dtype=torch.int32, + device=protein["aatype"].device, + ) + restype_atom37_to_atom14 = torch.tensor( + restype_atom37_to_atom14, + dtype=torch.int32, + device=protein["aatype"].device, + ) + restype_atom14_mask = torch.tensor( + restype_atom14_mask, + dtype=torch.float32, + device=protein["aatype"].device, + ) + protein_aatype = protein["aatype"].to(torch.long) + + # create the mapping for (residx, atom14) --> atom37, i.e. an array + # with shape (num_res, 14) containing the atom37 indices for this protein + residx_atom14_to_atom37 = restype_atom14_to_atom37[protein_aatype] + residx_atom14_mask = restype_atom14_mask[protein_aatype] + + protein["atom14_atom_exists"] = residx_atom14_mask + protein["residx_atom14_to_atom37"] = residx_atom14_to_atom37.long() + + # create the gather indices for mapping back + residx_atom37_to_atom14 = restype_atom37_to_atom14[protein_aatype] + protein["residx_atom37_to_atom14"] = residx_atom37_to_atom14.long() + + # create the corresponding mask + restype_atom37_mask = torch.zeros([21, 37], dtype=torch.float32, device=protein["aatype"].device) + for restype, restype_letter in enumerate(rc.restypes): + restype_name = rc.restype_1to3[restype_letter] + atom_names = rc.residue_atoms[restype_name] + for atom_name in atom_names: + atom_type = rc.atom_order[atom_name] + restype_atom37_mask[restype, atom_type] = 1 + + residx_atom37_mask = restype_atom37_mask[protein_aatype] + protein["atom37_atom_exists"] = residx_atom37_mask + + return protein + + +def make_atom14_masks_np(batch): + batch = tree_map(lambda n: torch.tensor(n, device=batch["aatype"].device), batch, np.ndarray) + out = make_atom14_masks(batch) + out = tensor_tree_map(lambda t: np.array(t), out) + return out diff --git a/src/transformers/models/esm/openfold_utils/feats.py b/src/transformers/models/esm/openfold_utils/feats.py new file mode 100644 index 00000000000000..dbfda88805f7bf --- /dev/null +++ b/src/transformers/models/esm/openfold_utils/feats.py @@ -0,0 +1,234 @@ +# Copyright 2021 AlQuraishi Laboratory +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.nn as nn + +from . import residue_constants as rc +from .rigid_utils import Rigid, Rotation +from .tensor_utils import batched_gather + + +def pseudo_beta_fn(aatype, all_atom_positions, all_atom_masks): + is_gly = aatype == rc.restype_order["G"] + ca_idx = rc.atom_order["CA"] + cb_idx = rc.atom_order["CB"] + pseudo_beta = torch.where( + is_gly[..., None].expand(*((-1,) * len(is_gly.shape)), 3), + all_atom_positions[..., ca_idx, :], + all_atom_positions[..., cb_idx, :], + ) + + if all_atom_masks is not None: + pseudo_beta_mask = torch.where( + is_gly, + all_atom_masks[..., ca_idx], + all_atom_masks[..., cb_idx], + ) + return pseudo_beta, pseudo_beta_mask + else: + return pseudo_beta + + +def atom14_to_atom37(atom14, batch): + atom37_data = batched_gather( + atom14, + batch["residx_atom37_to_atom14"], + dim=-2, + no_batch_dims=len(atom14.shape[:-2]), + ) + + atom37_data = atom37_data * batch["atom37_atom_exists"][..., None] + + return atom37_data + + +def build_template_angle_feat(template_feats): + template_aatype = template_feats["template_aatype"] + torsion_angles_sin_cos = template_feats["template_torsion_angles_sin_cos"] + alt_torsion_angles_sin_cos = template_feats["template_alt_torsion_angles_sin_cos"] + torsion_angles_mask = template_feats["template_torsion_angles_mask"] + template_angle_feat = torch.cat( + [ + nn.functional.one_hot(template_aatype, 22), + torsion_angles_sin_cos.reshape(*torsion_angles_sin_cos.shape[:-2], 14), + alt_torsion_angles_sin_cos.reshape(*alt_torsion_angles_sin_cos.shape[:-2], 14), + torsion_angles_mask, + ], + dim=-1, + ) + + return template_angle_feat + + +def build_template_pair_feat(batch, min_bin, max_bin, no_bins, use_unit_vector=False, eps=1e-20, inf=1e8): + template_mask = batch["template_pseudo_beta_mask"] + template_mask_2d = template_mask[..., None] * template_mask[..., None, :] + + # Compute distogram (this seems to differ slightly from Alg. 5) + tpb = batch["template_pseudo_beta"] + dgram = torch.sum((tpb[..., None, :] - tpb[..., None, :, :]) ** 2, dim=-1, keepdim=True) + lower = torch.linspace(min_bin, max_bin, no_bins, device=tpb.device) ** 2 + upper = torch.cat([lower[1:], lower.new_tensor([inf])], dim=-1) + dgram = ((dgram > lower) * (dgram < upper)).type(dgram.dtype) + + to_concat = [dgram, template_mask_2d[..., None]] + + aatype_one_hot = nn.functional.one_hot( + batch["template_aatype"], + rc.restype_num + 2, + ) + + n_res = batch["template_aatype"].shape[-1] + to_concat.append(aatype_one_hot[..., None, :, :].expand(*aatype_one_hot.shape[:-2], n_res, -1, -1)) + to_concat.append(aatype_one_hot[..., None, :].expand(*aatype_one_hot.shape[:-2], -1, n_res, -1)) + + n, ca, c = [rc.atom_order[a] for a in ["N", "CA", "C"]] + rigids = Rigid.make_transform_from_reference( + n_xyz=batch["template_all_atom_positions"][..., n, :], + ca_xyz=batch["template_all_atom_positions"][..., ca, :], + c_xyz=batch["template_all_atom_positions"][..., c, :], + eps=eps, + ) + points = rigids.get_trans()[..., None, :, :] + rigid_vec = rigids[..., None].invert_apply(points) + + inv_distance_scalar = torch.rsqrt(eps + torch.sum(rigid_vec**2, dim=-1)) + + t_aa_masks = batch["template_all_atom_mask"] + template_mask = t_aa_masks[..., n] * t_aa_masks[..., ca] * t_aa_masks[..., c] + template_mask_2d = template_mask[..., None] * template_mask[..., None, :] + + inv_distance_scalar = inv_distance_scalar * template_mask_2d + unit_vector = rigid_vec * inv_distance_scalar[..., None] + + if not use_unit_vector: + unit_vector = unit_vector * 0.0 + + to_concat.extend(torch.unbind(unit_vector[..., None, :], dim=-1)) + to_concat.append(template_mask_2d[..., None]) + + act = torch.cat(to_concat, dim=-1) + act = act * template_mask_2d[..., None] + + return act + + +def build_extra_msa_feat(batch): + msa_1hot = nn.functional.one_hot(batch["extra_msa"], 23) + msa_feat = [ + msa_1hot, + batch["extra_has_deletion"].unsqueeze(-1), + batch["extra_deletion_value"].unsqueeze(-1), + ] + return torch.cat(msa_feat, dim=-1) + + +def torsion_angles_to_frames( + r: Rigid, + alpha: torch.Tensor, + aatype: torch.Tensor, + rrgdf: torch.Tensor, +): + # [*, N, 8, 4, 4] + default_4x4 = rrgdf[aatype, ...] + + # [*, N, 8] transformations, i.e. + # One [*, N, 8, 3, 3] rotation matrix and + # One [*, N, 8, 3] translation matrix + default_r = r.from_tensor_4x4(default_4x4) + + bb_rot = alpha.new_zeros((*((1,) * len(alpha.shape[:-1])), 2)) + bb_rot[..., 1] = 1 + + # [*, N, 8, 2] + alpha = torch.cat([bb_rot.expand(*alpha.shape[:-2], -1, -1), alpha], dim=-2) + + # [*, N, 8, 3, 3] + # Produces rotation matrices of the form: + # [ + # [1, 0 , 0 ], + # [0, a_2,-a_1], + # [0, a_1, a_2] + # ] + # This follows the original code rather than the supplement, which uses + # different indices. + + all_rots = alpha.new_zeros(default_r.get_rots().get_rot_mats().shape) + all_rots[..., 0, 0] = 1 + all_rots[..., 1, 1] = alpha[..., 1] + all_rots[..., 1, 2] = -alpha[..., 0] + all_rots[..., 2, 1:] = alpha + + all_rots = Rigid(Rotation(rot_mats=all_rots), None) + + all_frames = default_r.compose(all_rots) + + chi2_frame_to_frame = all_frames[..., 5] + chi3_frame_to_frame = all_frames[..., 6] + chi4_frame_to_frame = all_frames[..., 7] + + chi1_frame_to_bb = all_frames[..., 4] + chi2_frame_to_bb = chi1_frame_to_bb.compose(chi2_frame_to_frame) + chi3_frame_to_bb = chi2_frame_to_bb.compose(chi3_frame_to_frame) + chi4_frame_to_bb = chi3_frame_to_bb.compose(chi4_frame_to_frame) + + all_frames_to_bb = Rigid.cat( + [ + all_frames[..., :5], + chi2_frame_to_bb.unsqueeze(-1), + chi3_frame_to_bb.unsqueeze(-1), + chi4_frame_to_bb.unsqueeze(-1), + ], + dim=-1, + ) + + all_frames_to_global = r[..., None].compose(all_frames_to_bb) + + return all_frames_to_global + + +def frames_and_literature_positions_to_atom14_pos( + r: Rigid, + aatype: torch.Tensor, + default_frames, + group_idx, + atom_mask, + lit_positions, +): + # [*, N, 14] + group_mask = group_idx[aatype, ...] + + # [*, N, 14, 8] + group_mask = nn.functional.one_hot( + group_mask, + num_classes=default_frames.shape[-3], + ) + + # [*, N, 14, 8] + t_atoms_to_global = r[..., None, :] * group_mask + + # [*, N, 14] + t_atoms_to_global = t_atoms_to_global.map_tensor_fn(lambda x: torch.sum(x, dim=-1)) + + # [*, N, 14, 1] + atom_mask = atom_mask[aatype, ...].unsqueeze(-1) + + # [*, N, 14, 3] + lit_positions = lit_positions[aatype, ...] + pred_positions = t_atoms_to_global.apply(lit_positions) + pred_positions = pred_positions * atom_mask + + return pred_positions diff --git a/src/transformers/models/esm/openfold_utils/loss.py b/src/transformers/models/esm/openfold_utils/loss.py new file mode 100644 index 00000000000000..4d60b1049137c5 --- /dev/null +++ b/src/transformers/models/esm/openfold_utils/loss.py @@ -0,0 +1,105 @@ +# Copyright 2021 AlQuraishi Laboratory +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Dict, Optional, Tuple + +import torch + + +def _calculate_bin_centers(boundaries: torch.Tensor): + step = boundaries[1] - boundaries[0] + bin_centers = boundaries + step / 2 + bin_centers = torch.cat([bin_centers, (bin_centers[-1] + step).unsqueeze(-1)], dim=0) + return bin_centers + + +def _calculate_expected_aligned_error( + alignment_confidence_breaks: torch.Tensor, + aligned_distance_error_probs: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + bin_centers = _calculate_bin_centers(alignment_confidence_breaks) + return ( + torch.sum(aligned_distance_error_probs * bin_centers, dim=-1), + bin_centers[-1], + ) + + +def compute_predicted_aligned_error( + logits: torch.Tensor, + max_bin: int = 31, + no_bins: int = 64, + **kwargs, +) -> Dict[str, torch.Tensor]: + """Computes aligned confidence metrics from logits. + + Args: + logits: [*, num_res, num_res, num_bins] the logits output from + PredictedAlignedErrorHead. + max_bin: Maximum bin value + no_bins: Number of bins + Returns: + aligned_confidence_probs: [*, num_res, num_res, num_bins] the predicted + aligned error probabilities over bins for each residue pair. + predicted_aligned_error: [*, num_res, num_res] the expected aligned distance + error for each pair of residues. + max_predicted_aligned_error: [*] the maximum predicted error possible. + """ + boundaries = torch.linspace(0, max_bin, steps=(no_bins - 1), device=logits.device) + + aligned_confidence_probs = torch.nn.functional.softmax(logits, dim=-1) + (predicted_aligned_error, max_predicted_aligned_error,) = _calculate_expected_aligned_error( + alignment_confidence_breaks=boundaries, + aligned_distance_error_probs=aligned_confidence_probs, + ) + + return { + "aligned_confidence_probs": aligned_confidence_probs, + "predicted_aligned_error": predicted_aligned_error, + "max_predicted_aligned_error": max_predicted_aligned_error, + } + + +def compute_tm( + logits: torch.Tensor, + residue_weights: Optional[torch.Tensor] = None, + max_bin: int = 31, + no_bins: int = 64, + eps: float = 1e-8, + **kwargs, +) -> torch.Tensor: + if residue_weights is None: + residue_weights = logits.new_ones(logits.shape[-2]) + + boundaries = torch.linspace(0, max_bin, steps=(no_bins - 1), device=logits.device) + + bin_centers = _calculate_bin_centers(boundaries) + torch.sum(residue_weights) + n = logits.shape[-2] + clipped_n = max(n, 19) + + d0 = 1.24 * (clipped_n - 15) ** (1.0 / 3) - 1.8 + + probs = torch.nn.functional.softmax(logits, dim=-1) + + tm_per_bin = 1.0 / (1 + (bin_centers**2) / (d0**2)) + predicted_tm_term = torch.sum(probs * tm_per_bin, dim=-1) + + normed_residue_mask = residue_weights / (eps + residue_weights.sum()) + per_alignment = torch.sum(predicted_tm_term * normed_residue_mask, dim=-1) + + weighted = per_alignment * residue_weights + + argmax = (weighted == torch.max(weighted)).nonzero()[0] + return per_alignment[tuple(argmax)] diff --git a/src/transformers/models/esm/openfold_utils/protein.py b/src/transformers/models/esm/openfold_utils/protein.py new file mode 100644 index 00000000000000..750027117a8164 --- /dev/null +++ b/src/transformers/models/esm/openfold_utils/protein.py @@ -0,0 +1,329 @@ +# Copyright 2021 AlQuraishi Laboratory +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Protein data type.""" +import dataclasses +import re +import string +from typing import Any, Mapping, Optional, Sequence + +import numpy as np + +from . import residue_constants + + +FeatureDict = Mapping[str, np.ndarray] +ModelOutput = Mapping[str, Any] # Is a nested dict. +PICO_TO_ANGSTROM = 0.01 + + +@dataclasses.dataclass(frozen=True) +class Protein: + """Protein structure representation.""" + + # Cartesian coordinates of atoms in angstroms. The atom types correspond to + # residue_constants.atom_types, i.e. the first three are N, CA, CB. + atom_positions: np.ndarray # [num_res, num_atom_type, 3] + + # Amino-acid type for each residue represented as an integer between 0 and + # 20, where 20 is 'X'. + aatype: np.ndarray # [num_res] + + # Binary float mask to indicate presence of a particular atom. 1.0 if an atom + # is present and 0.0 if not. This should be used for loss masking. + atom_mask: np.ndarray # [num_res, num_atom_type] + + # Residue index as used in PDB. It is not necessarily continuous or 0-indexed. + residue_index: np.ndarray # [num_res] + + # B-factors, or temperature factors, of each residue (in sq. angstroms units), + # representing the displacement of the residue from its ground truth mean + # value. + b_factors: np.ndarray # [num_res, num_atom_type] + + # Chain indices for multi-chain predictions + chain_index: Optional[np.ndarray] = None + + # Optional remark about the protein. Included as a comment in output PDB + # files + remark: Optional[str] = None + + # Templates used to generate this protein (prediction-only) + parents: Optional[Sequence[str]] = None + + # Chain corresponding to each parent + parents_chain_index: Optional[Sequence[int]] = None + + +def from_proteinnet_string(proteinnet_str: str) -> Protein: + tag_re = r"(\[[A-Z]+\]\n)" + tags = [tag.strip() for tag in re.split(tag_re, proteinnet_str) if len(tag) > 0] + groups = zip(tags[0::2], [l.split("\n") for l in tags[1::2]]) + + atoms = ["N", "CA", "C"] + aatype = None + atom_positions = None + atom_mask = None + for g in groups: + if "[PRIMARY]" == g[0]: + seq = g[1][0].strip() + for i in range(len(seq)): + if seq[i] not in residue_constants.restypes: + seq[i] = "X" + aatype = np.array( + [residue_constants.restype_order.get(res_symbol, residue_constants.restype_num) for res_symbol in seq] + ) + elif "[TERTIARY]" == g[0]: + tertiary = [] + for axis in range(3): + tertiary.append(list(map(float, g[1][axis].split()))) + tertiary_np = np.array(tertiary) + atom_positions = np.zeros((len(tertiary[0]) // 3, residue_constants.atom_type_num, 3)).astype(np.float32) + for i, atom in enumerate(atoms): + atom_positions[:, residue_constants.atom_order[atom], :] = np.transpose(tertiary_np[:, i::3]) + atom_positions *= PICO_TO_ANGSTROM + elif "[MASK]" == g[0]: + mask = np.array(list(map({"-": 0, "+": 1}.get, g[1][0].strip()))) + atom_mask = np.zeros( + ( + len(mask), + residue_constants.atom_type_num, + ) + ).astype(np.float32) + for i, atom in enumerate(atoms): + atom_mask[:, residue_constants.atom_order[atom]] = 1 + atom_mask *= mask[..., None] + + return Protein( + atom_positions=atom_positions, + atom_mask=atom_mask, + aatype=aatype, + residue_index=np.arange(len(aatype)), + b_factors=None, + ) + + +def get_pdb_headers(prot: Protein, chain_id: int = 0) -> Sequence[str]: + pdb_headers = [] + + remark = prot.remark + if remark is not None: + pdb_headers.append(f"REMARK {remark}") + + parents = prot.parents + parents_chain_index = prot.parents_chain_index + if parents_chain_index is not None: + parents = [p for i, p in zip(parents_chain_index, parents) if i == chain_id] + + if parents is None or len(parents) == 0: + parents = ["N/A"] + + pdb_headers.append(f"PARENT {' '.join(parents)}") + + return pdb_headers + + +def add_pdb_headers(prot: Protein, pdb_str: str) -> str: + """Add pdb headers to an existing PDB string. Useful during multi-chain + recycling + """ + out_pdb_lines = [] + lines = pdb_str.split("\n") + + remark = prot.remark + if remark is not None: + out_pdb_lines.append(f"REMARK {remark}") + + parents_per_chain = None + if prot.parents is not None and len(prot.parents) > 0: + parents_per_chain = [] + if prot.parents_chain_index is not None: + parent_dict = {} + for p, i in zip(prot.parents, prot.parents_chain_index): + parent_dict.setdefault(str(i), []) + parent_dict[str(i)].append(p) + + max_idx = max([int(chain_idx) for chain_idx in parent_dict]) + for i in range(max_idx + 1): + chain_parents = parent_dict.get(str(i), ["N/A"]) + parents_per_chain.append(chain_parents) + else: + parents_per_chain.append(prot.parents) + else: + parents_per_chain = [["N/A"]] + + def make_parent_line(p): + return f"PARENT {' '.join(p)}" + + out_pdb_lines.append(make_parent_line(parents_per_chain[0])) + + chain_counter = 0 + for i, l in enumerate(lines): + if "PARENT" not in l and "REMARK" not in l: + out_pdb_lines.append(l) + if "TER" in l and "END" not in lines[i + 1]: + chain_counter += 1 + if not chain_counter >= len(parents_per_chain): + chain_parents = parents_per_chain[chain_counter] + else: + chain_parents = ["N/A"] + + out_pdb_lines.append(make_parent_line(chain_parents)) + + return "\n".join(out_pdb_lines) + + +def to_pdb(prot: Protein) -> str: + """Converts a `Protein` instance to a PDB string. + + Args: + prot: The protein to convert to PDB. + + Returns: + PDB string. + """ + restypes = residue_constants.restypes + ["X"] + + def res_1to3(r): + return residue_constants.restype_1to3.get(restypes[r], "UNK") + + atom_types = residue_constants.atom_types + + pdb_lines = [] + + atom_mask = prot.atom_mask + aatype = prot.aatype + atom_positions = prot.atom_positions + residue_index = prot.residue_index.astype(np.int32) + b_factors = prot.b_factors + chain_index = prot.chain_index + + if np.any(aatype > residue_constants.restype_num): + raise ValueError("Invalid aatypes.") + + headers = get_pdb_headers(prot) + if len(headers) > 0: + pdb_lines.extend(headers) + + n = aatype.shape[0] + atom_index = 1 + prev_chain_index = 0 + chain_tags = string.ascii_uppercase + # Add all atom sites. + for i in range(n): + res_name_3 = res_1to3(aatype[i]) + for atom_name, pos, mask, b_factor in zip(atom_types, atom_positions[i], atom_mask[i], b_factors[i]): + if mask < 0.5: + continue + + record_type = "ATOM" + name = atom_name if len(atom_name) == 4 else f" {atom_name}" + alt_loc = "" + insertion_code = "" + occupancy = 1.00 + element = atom_name[0] # Protein supports only C, N, O, S, this works. + charge = "" + + chain_tag = "A" + if chain_index is not None: + chain_tag = chain_tags[chain_index[i]] + + # PDB is a columnar format, every space matters here! + atom_line = ( + f"{record_type:<6}{atom_index:>5} {name:<4}{alt_loc:>1}" + f"{res_name_3:>3} {chain_tag:>1}" + f"{residue_index[i]:>4}{insertion_code:>1} " + f"{pos[0]:>8.3f}{pos[1]:>8.3f}{pos[2]:>8.3f}" + f"{occupancy:>6.2f}{b_factor:>6.2f} " + f"{element:>2}{charge:>2}" + ) + pdb_lines.append(atom_line) + atom_index += 1 + + should_terminate = i == n - 1 + if chain_index is not None: + if i != n - 1 and chain_index[i + 1] != prev_chain_index: + should_terminate = True + prev_chain_index = chain_index[i + 1] + + if should_terminate: + # Close the chain. + chain_end = "TER" + chain_termination_line = ( + f"{chain_end:<6}{atom_index:>5} {res_1to3(aatype[i]):>3} {chain_tag:>1}{residue_index[i]:>4}" + ) + pdb_lines.append(chain_termination_line) + atom_index += 1 + + if i != n - 1: + # "prev" is a misnomer here. This happens at the beginning of + # each new chain. + pdb_lines.extend(get_pdb_headers(prot, prev_chain_index)) + + pdb_lines.append("END") + pdb_lines.append("") + return "\n".join(pdb_lines) + + +def ideal_atom_mask(prot: Protein) -> np.ndarray: + """Computes an ideal atom mask. + + `Protein.atom_mask` typically is defined according to the atoms that are reported in the PDB. This function + computes a mask according to heavy atoms that should be present in the given sequence of amino acids. + + Args: + prot: `Protein` whose fields are `numpy.ndarray` objects. + + Returns: + An ideal atom mask. + """ + return residue_constants.STANDARD_ATOM_MASK[prot.aatype] + + +def from_prediction( + features: FeatureDict, + result: ModelOutput, + b_factors: Optional[np.ndarray] = None, + chain_index: Optional[np.ndarray] = None, + remark: Optional[str] = None, + parents: Optional[Sequence[str]] = None, + parents_chain_index: Optional[Sequence[int]] = None, +) -> Protein: + """Assembles a protein from a prediction. + + Args: + features: Dictionary holding model inputs. + result: Dictionary holding model outputs. + b_factors: (Optional) B-factors to use for the protein. + chain_index: (Optional) Chain indices for multi-chain predictions + remark: (Optional) Remark about the prediction + parents: (Optional) List of template names + Returns: + A protein instance. + """ + if b_factors is None: + b_factors = np.zeros_like(result["final_atom_mask"]) + + return Protein( + aatype=features["aatype"], + atom_positions=result["final_atom_positions"], + atom_mask=result["final_atom_mask"], + residue_index=features["residue_index"] + 1, + b_factors=b_factors, + chain_index=chain_index, + remark=remark, + parents=parents, + parents_chain_index=parents_chain_index, + ) diff --git a/src/transformers/models/esm/openfold_utils/residue_constants.py b/src/transformers/models/esm/openfold_utils/residue_constants.py new file mode 100644 index 00000000000000..37f0fc081d069e --- /dev/null +++ b/src/transformers/models/esm/openfold_utils/residue_constants.py @@ -0,0 +1,1251 @@ +# Copyright 2021 AlQuraishi Laboratory +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Constants used in AlphaFold.""" + +import collections +import copy +import functools +from importlib import resources +from typing import List, Mapping, Tuple + +import numpy as np + + +# Internal import (35fd). + + +# Distance from one CA to next CA [trans configuration: omega = 180]. +ca_ca = 3.80209737096 + +# Format: The list for each AA type contains chi1, chi2, chi3, chi4 in +# this order (or a relevant subset from chi1 onwards). ALA and GLY don't have +# chi angles so their chi angle lists are empty. +chi_angles_atoms = { + "ALA": [], + # Chi5 in arginine is always 0 +- 5 degrees, so ignore it. + "ARG": [ + ["N", "CA", "CB", "CG"], + ["CA", "CB", "CG", "CD"], + ["CB", "CG", "CD", "NE"], + ["CG", "CD", "NE", "CZ"], + ], + "ASN": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "OD1"]], + "ASP": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "OD1"]], + "CYS": [["N", "CA", "CB", "SG"]], + "GLN": [ + ["N", "CA", "CB", "CG"], + ["CA", "CB", "CG", "CD"], + ["CB", "CG", "CD", "OE1"], + ], + "GLU": [ + ["N", "CA", "CB", "CG"], + ["CA", "CB", "CG", "CD"], + ["CB", "CG", "CD", "OE1"], + ], + "GLY": [], + "HIS": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "ND1"]], + "ILE": [["N", "CA", "CB", "CG1"], ["CA", "CB", "CG1", "CD1"]], + "LEU": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "CD1"]], + "LYS": [ + ["N", "CA", "CB", "CG"], + ["CA", "CB", "CG", "CD"], + ["CB", "CG", "CD", "CE"], + ["CG", "CD", "CE", "NZ"], + ], + "MET": [ + ["N", "CA", "CB", "CG"], + ["CA", "CB", "CG", "SD"], + ["CB", "CG", "SD", "CE"], + ], + "PHE": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "CD1"]], + "PRO": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "CD"]], + "SER": [["N", "CA", "CB", "OG"]], + "THR": [["N", "CA", "CB", "OG1"]], + "TRP": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "CD1"]], + "TYR": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "CD1"]], + "VAL": [["N", "CA", "CB", "CG1"]], +} + +# If chi angles given in fixed-length array, this matrix determines how to mask +# them for each AA type. The order is as per restype_order (see below). +chi_angles_mask = [ + [0.0, 0.0, 0.0, 0.0], # ALA + [1.0, 1.0, 1.0, 1.0], # ARG + [1.0, 1.0, 0.0, 0.0], # ASN + [1.0, 1.0, 0.0, 0.0], # ASP + [1.0, 0.0, 0.0, 0.0], # CYS + [1.0, 1.0, 1.0, 0.0], # GLN + [1.0, 1.0, 1.0, 0.0], # GLU + [0.0, 0.0, 0.0, 0.0], # GLY + [1.0, 1.0, 0.0, 0.0], # HIS + [1.0, 1.0, 0.0, 0.0], # ILE + [1.0, 1.0, 0.0, 0.0], # LEU + [1.0, 1.0, 1.0, 1.0], # LYS + [1.0, 1.0, 1.0, 0.0], # MET + [1.0, 1.0, 0.0, 0.0], # PHE + [1.0, 1.0, 0.0, 0.0], # PRO + [1.0, 0.0, 0.0, 0.0], # SER + [1.0, 0.0, 0.0, 0.0], # THR + [1.0, 1.0, 0.0, 0.0], # TRP + [1.0, 1.0, 0.0, 0.0], # TYR + [1.0, 0.0, 0.0, 0.0], # VAL +] + +# The following chi angles are pi periodic: they can be rotated by a multiple +# of pi without affecting the structure. +chi_pi_periodic = [ + [0.0, 0.0, 0.0, 0.0], # ALA + [0.0, 0.0, 0.0, 0.0], # ARG + [0.0, 0.0, 0.0, 0.0], # ASN + [0.0, 1.0, 0.0, 0.0], # ASP + [0.0, 0.0, 0.0, 0.0], # CYS + [0.0, 0.0, 0.0, 0.0], # GLN + [0.0, 0.0, 1.0, 0.0], # GLU + [0.0, 0.0, 0.0, 0.0], # GLY + [0.0, 0.0, 0.0, 0.0], # HIS + [0.0, 0.0, 0.0, 0.0], # ILE + [0.0, 0.0, 0.0, 0.0], # LEU + [0.0, 0.0, 0.0, 0.0], # LYS + [0.0, 0.0, 0.0, 0.0], # MET + [0.0, 1.0, 0.0, 0.0], # PHE + [0.0, 0.0, 0.0, 0.0], # PRO + [0.0, 0.0, 0.0, 0.0], # SER + [0.0, 0.0, 0.0, 0.0], # THR + [0.0, 0.0, 0.0, 0.0], # TRP + [0.0, 1.0, 0.0, 0.0], # TYR + [0.0, 0.0, 0.0, 0.0], # VAL + [0.0, 0.0, 0.0, 0.0], # UNK +] + +# Atoms positions relative to the 8 rigid groups, defined by the pre-omega, phi, +# psi and chi angles: +# 0: 'backbone group', +# 1: 'pre-omega-group', (empty) +# 2: 'phi-group', (currently empty, because it defines only hydrogens) +# 3: 'psi-group', +# 4,5,6,7: 'chi1,2,3,4-group' +# The atom positions are relative to the axis-end-atom of the corresponding +# rotation axis. The x-axis is in direction of the rotation axis, and the y-axis +# is defined such that the dihedral-angle-definiting atom (the last entry in +# chi_angles_atoms above) is in the xy-plane (with a positive y-coordinate). +# format: [atomname, group_idx, rel_position] +rigid_group_atom_positions = { + "ALA": [ + ["N", 0, (-0.525, 1.363, 0.000)], + ["CA", 0, (0.000, 0.000, 0.000)], + ["C", 0, (1.526, -0.000, -0.000)], + ["CB", 0, (-0.529, -0.774, -1.205)], + ["O", 3, (0.627, 1.062, 0.000)], + ], + "ARG": [ + ["N", 0, (-0.524, 1.362, -0.000)], + ["CA", 0, (0.000, 0.000, 0.000)], + ["C", 0, (1.525, -0.000, -0.000)], + ["CB", 0, (-0.524, -0.778, -1.209)], + ["O", 3, (0.626, 1.062, 0.000)], + ["CG", 4, (0.616, 1.390, -0.000)], + ["CD", 5, (0.564, 1.414, 0.000)], + ["NE", 6, (0.539, 1.357, -0.000)], + ["NH1", 7, (0.206, 2.301, 0.000)], + ["NH2", 7, (2.078, 0.978, -0.000)], + ["CZ", 7, (0.758, 1.093, -0.000)], + ], + "ASN": [ + ["N", 0, (-0.536, 1.357, 0.000)], + ["CA", 0, (0.000, 0.000, 0.000)], + ["C", 0, (1.526, -0.000, -0.000)], + ["CB", 0, (-0.531, -0.787, -1.200)], + ["O", 3, (0.625, 1.062, 0.000)], + ["CG", 4, (0.584, 1.399, 0.000)], + ["ND2", 5, (0.593, -1.188, 0.001)], + ["OD1", 5, (0.633, 1.059, 0.000)], + ], + "ASP": [ + ["N", 0, (-0.525, 1.362, -0.000)], + ["CA", 0, (0.000, 0.000, 0.000)], + ["C", 0, (1.527, 0.000, -0.000)], + ["CB", 0, (-0.526, -0.778, -1.208)], + ["O", 3, (0.626, 1.062, -0.000)], + ["CG", 4, (0.593, 1.398, -0.000)], + ["OD1", 5, (0.610, 1.091, 0.000)], + ["OD2", 5, (0.592, -1.101, -0.003)], + ], + "CYS": [ + ["N", 0, (-0.522, 1.362, -0.000)], + ["CA", 0, (0.000, 0.000, 0.000)], + ["C", 0, (1.524, 0.000, 0.000)], + ["CB", 0, (-0.519, -0.773, -1.212)], + ["O", 3, (0.625, 1.062, -0.000)], + ["SG", 4, (0.728, 1.653, 0.000)], + ], + "GLN": [ + ["N", 0, (-0.526, 1.361, -0.000)], + ["CA", 0, (0.000, 0.000, 0.000)], + ["C", 0, (1.526, 0.000, 0.000)], + ["CB", 0, (-0.525, -0.779, -1.207)], + ["O", 3, (0.626, 1.062, -0.000)], + ["CG", 4, (0.615, 1.393, 0.000)], + ["CD", 5, (0.587, 1.399, -0.000)], + ["NE2", 6, (0.593, -1.189, -0.001)], + ["OE1", 6, (0.634, 1.060, 0.000)], + ], + "GLU": [ + ["N", 0, (-0.528, 1.361, 0.000)], + ["CA", 0, (0.000, 0.000, 0.000)], + ["C", 0, (1.526, -0.000, -0.000)], + ["CB", 0, (-0.526, -0.781, -1.207)], + ["O", 3, (0.626, 1.062, 0.000)], + ["CG", 4, (0.615, 1.392, 0.000)], + ["CD", 5, (0.600, 1.397, 0.000)], + ["OE1", 6, (0.607, 1.095, -0.000)], + ["OE2", 6, (0.589, -1.104, -0.001)], + ], + "GLY": [ + ["N", 0, (-0.572, 1.337, 0.000)], + ["CA", 0, (0.000, 0.000, 0.000)], + ["C", 0, (1.517, -0.000, -0.000)], + ["O", 3, (0.626, 1.062, -0.000)], + ], + "HIS": [ + ["N", 0, (-0.527, 1.360, 0.000)], + ["CA", 0, (0.000, 0.000, 0.000)], + ["C", 0, (1.525, 0.000, 0.000)], + ["CB", 0, (-0.525, -0.778, -1.208)], + ["O", 3, (0.625, 1.063, 0.000)], + ["CG", 4, (0.600, 1.370, -0.000)], + ["CD2", 5, (0.889, -1.021, 0.003)], + ["ND1", 5, (0.744, 1.160, -0.000)], + ["CE1", 5, (2.030, 0.851, 0.002)], + ["NE2", 5, (2.145, -0.466, 0.004)], + ], + "ILE": [ + ["N", 0, (-0.493, 1.373, -0.000)], + ["CA", 0, (0.000, 0.000, 0.000)], + ["C", 0, (1.527, -0.000, -0.000)], + ["CB", 0, (-0.536, -0.793, -1.213)], + ["O", 3, (0.627, 1.062, -0.000)], + ["CG1", 4, (0.534, 1.437, -0.000)], + ["CG2", 4, (0.540, -0.785, -1.199)], + ["CD1", 5, (0.619, 1.391, 0.000)], + ], + "LEU": [ + ["N", 0, (-0.520, 1.363, 0.000)], + ["CA", 0, (0.000, 0.000, 0.000)], + ["C", 0, (1.525, -0.000, -0.000)], + ["CB", 0, (-0.522, -0.773, -1.214)], + ["O", 3, (0.625, 1.063, -0.000)], + ["CG", 4, (0.678, 1.371, 0.000)], + ["CD1", 5, (0.530, 1.430, -0.000)], + ["CD2", 5, (0.535, -0.774, 1.200)], + ], + "LYS": [ + ["N", 0, (-0.526, 1.362, -0.000)], + ["CA", 0, (0.000, 0.000, 0.000)], + ["C", 0, (1.526, 0.000, 0.000)], + ["CB", 0, (-0.524, -0.778, -1.208)], + ["O", 3, (0.626, 1.062, -0.000)], + ["CG", 4, (0.619, 1.390, 0.000)], + ["CD", 5, (0.559, 1.417, 0.000)], + ["CE", 6, (0.560, 1.416, 0.000)], + ["NZ", 7, (0.554, 1.387, 0.000)], + ], + "MET": [ + ["N", 0, (-0.521, 1.364, -0.000)], + ["CA", 0, (0.000, 0.000, 0.000)], + ["C", 0, (1.525, 0.000, 0.000)], + ["CB", 0, (-0.523, -0.776, -1.210)], + ["O", 3, (0.625, 1.062, -0.000)], + ["CG", 4, (0.613, 1.391, -0.000)], + ["SD", 5, (0.703, 1.695, 0.000)], + ["CE", 6, (0.320, 1.786, -0.000)], + ], + "PHE": [ + ["N", 0, (-0.518, 1.363, 0.000)], + ["CA", 0, (0.000, 0.000, 0.000)], + ["C", 0, (1.524, 0.000, -0.000)], + ["CB", 0, (-0.525, -0.776, -1.212)], + ["O", 3, (0.626, 1.062, -0.000)], + ["CG", 4, (0.607, 1.377, 0.000)], + ["CD1", 5, (0.709, 1.195, -0.000)], + ["CD2", 5, (0.706, -1.196, 0.000)], + ["CE1", 5, (2.102, 1.198, -0.000)], + ["CE2", 5, (2.098, -1.201, -0.000)], + ["CZ", 5, (2.794, -0.003, -0.001)], + ], + "PRO": [ + ["N", 0, (-0.566, 1.351, -0.000)], + ["CA", 0, (0.000, 0.000, 0.000)], + ["C", 0, (1.527, -0.000, 0.000)], + ["CB", 0, (-0.546, -0.611, -1.293)], + ["O", 3, (0.621, 1.066, 0.000)], + ["CG", 4, (0.382, 1.445, 0.0)], + # ['CD', 5, (0.427, 1.440, 0.0)], + ["CD", 5, (0.477, 1.424, 0.0)], # manually made angle 2 degrees larger + ], + "SER": [ + ["N", 0, (-0.529, 1.360, -0.000)], + ["CA", 0, (0.000, 0.000, 0.000)], + ["C", 0, (1.525, -0.000, -0.000)], + ["CB", 0, (-0.518, -0.777, -1.211)], + ["O", 3, (0.626, 1.062, -0.000)], + ["OG", 4, (0.503, 1.325, 0.000)], + ], + "THR": [ + ["N", 0, (-0.517, 1.364, 0.000)], + ["CA", 0, (0.000, 0.000, 0.000)], + ["C", 0, (1.526, 0.000, -0.000)], + ["CB", 0, (-0.516, -0.793, -1.215)], + ["O", 3, (0.626, 1.062, 0.000)], + ["CG2", 4, (0.550, -0.718, -1.228)], + ["OG1", 4, (0.472, 1.353, 0.000)], + ], + "TRP": [ + ["N", 0, (-0.521, 1.363, 0.000)], + ["CA", 0, (0.000, 0.000, 0.000)], + ["C", 0, (1.525, -0.000, 0.000)], + ["CB", 0, (-0.523, -0.776, -1.212)], + ["O", 3, (0.627, 1.062, 0.000)], + ["CG", 4, (0.609, 1.370, -0.000)], + ["CD1", 5, (0.824, 1.091, 0.000)], + ["CD2", 5, (0.854, -1.148, -0.005)], + ["CE2", 5, (2.186, -0.678, -0.007)], + ["CE3", 5, (0.622, -2.530, -0.007)], + ["NE1", 5, (2.140, 0.690, -0.004)], + ["CH2", 5, (3.028, -2.890, -0.013)], + ["CZ2", 5, (3.283, -1.543, -0.011)], + ["CZ3", 5, (1.715, -3.389, -0.011)], + ], + "TYR": [ + ["N", 0, (-0.522, 1.362, 0.000)], + ["CA", 0, (0.000, 0.000, 0.000)], + ["C", 0, (1.524, -0.000, -0.000)], + ["CB", 0, (-0.522, -0.776, -1.213)], + ["O", 3, (0.627, 1.062, -0.000)], + ["CG", 4, (0.607, 1.382, -0.000)], + ["CD1", 5, (0.716, 1.195, -0.000)], + ["CD2", 5, (0.713, -1.194, -0.001)], + ["CE1", 5, (2.107, 1.200, -0.002)], + ["CE2", 5, (2.104, -1.201, -0.003)], + ["OH", 5, (4.168, -0.002, -0.005)], + ["CZ", 5, (2.791, -0.001, -0.003)], + ], + "VAL": [ + ["N", 0, (-0.494, 1.373, -0.000)], + ["CA", 0, (0.000, 0.000, 0.000)], + ["C", 0, (1.527, -0.000, -0.000)], + ["CB", 0, (-0.533, -0.795, -1.213)], + ["O", 3, (0.627, 1.062, -0.000)], + ["CG1", 4, (0.540, 1.429, -0.000)], + ["CG2", 4, (0.533, -0.776, 1.203)], + ], +} + +# A list of atoms (excluding hydrogen) for each AA type. PDB naming convention. +residue_atoms = { + "ALA": ["C", "CA", "CB", "N", "O"], + "ARG": ["C", "CA", "CB", "CG", "CD", "CZ", "N", "NE", "O", "NH1", "NH2"], + "ASP": ["C", "CA", "CB", "CG", "N", "O", "OD1", "OD2"], + "ASN": ["C", "CA", "CB", "CG", "N", "ND2", "O", "OD1"], + "CYS": ["C", "CA", "CB", "N", "O", "SG"], + "GLU": ["C", "CA", "CB", "CG", "CD", "N", "O", "OE1", "OE2"], + "GLN": ["C", "CA", "CB", "CG", "CD", "N", "NE2", "O", "OE1"], + "GLY": ["C", "CA", "N", "O"], + "HIS": ["C", "CA", "CB", "CG", "CD2", "CE1", "N", "ND1", "NE2", "O"], + "ILE": ["C", "CA", "CB", "CG1", "CG2", "CD1", "N", "O"], + "LEU": ["C", "CA", "CB", "CG", "CD1", "CD2", "N", "O"], + "LYS": ["C", "CA", "CB", "CG", "CD", "CE", "N", "NZ", "O"], + "MET": ["C", "CA", "CB", "CG", "CE", "N", "O", "SD"], + "PHE": ["C", "CA", "CB", "CG", "CD1", "CD2", "CE1", "CE2", "CZ", "N", "O"], + "PRO": ["C", "CA", "CB", "CG", "CD", "N", "O"], + "SER": ["C", "CA", "CB", "N", "O", "OG"], + "THR": ["C", "CA", "CB", "CG2", "N", "O", "OG1"], + "TRP": [ + "C", + "CA", + "CB", + "CG", + "CD1", + "CD2", + "CE2", + "CE3", + "CZ2", + "CZ3", + "CH2", + "N", + "NE1", + "O", + ], + "TYR": [ + "C", + "CA", + "CB", + "CG", + "CD1", + "CD2", + "CE1", + "CE2", + "CZ", + "N", + "O", + "OH", + ], + "VAL": ["C", "CA", "CB", "CG1", "CG2", "N", "O"], +} + +# Naming swaps for ambiguous atom names. +# Due to symmetries in the amino acids the naming of atoms is ambiguous in +# 4 of the 20 amino acids. +# (The LDDT paper lists 7 amino acids as ambiguous, but the naming ambiguities +# in LEU, VAL and ARG can be resolved by using the 3d constellations of +# the 'ambiguous' atoms and their neighbours) +# TODO: ^ interpret this +residue_atom_renaming_swaps = { + "ASP": {"OD1": "OD2"}, + "GLU": {"OE1": "OE2"}, + "PHE": {"CD1": "CD2", "CE1": "CE2"}, + "TYR": {"CD1": "CD2", "CE1": "CE2"}, +} + +# Van der Waals radii [Angstroem] of the atoms (from Wikipedia) +van_der_waals_radius = { + "C": 1.7, + "N": 1.55, + "O": 1.52, + "S": 1.8, +} + +Bond = collections.namedtuple("Bond", ["atom1_name", "atom2_name", "length", "stddev"]) +BondAngle = collections.namedtuple( + "BondAngle", + ["atom1_name", "atom2_name", "atom3name", "angle_rad", "stddev"], +) + + +def map_structure_with_atom_order(in_list: List, first_call: bool = True): + # Maps strings in a nested list structure to their corresponding index in atom_order + if first_call: + in_list = copy.deepcopy(in_list) + for i in range(len(in_list)): + if isinstance(in_list[i], list): + in_list[i] = map_structure_with_atom_order(in_list[i], first_call=False) + elif isinstance(in_list[i], str): + in_list[i] = atom_order[in_list[i]] + else: + raise ValueError("Unexpected type when mapping nested lists!") + return in_list + + +@functools.lru_cache(maxsize=None) +def load_stereo_chemical_props() -> Tuple[ + Mapping[str, List[Bond]], + Mapping[str, List[Bond]], + Mapping[str, List[BondAngle]], +]: + """Load stereo_chemical_props.txt into a nice structure. + + Load literature values for bond lengths and bond angles and translate bond angles into the length of the opposite + edge of the triangle ("residue_virtual_bonds"). + + Returns: + residue_bonds: dict that maps resname --> list of Bond tuples residue_virtual_bonds: dict that maps resname --> + list of Bond tuples residue_bond_angles: dict that maps resname --> list of BondAngle tuples + """ + # TODO: this file should be downloaded in a setup script + stereo_chemical_props = resources.read_text("openfold.resources", "stereo_chemical_props.txt") + + lines_iter = iter(stereo_chemical_props.splitlines()) + # Load bond lengths. + residue_bonds = {} + next(lines_iter) # Skip header line. + for line in lines_iter: + if line.strip() == "-": + break + bond, resname, length, stddev = line.split() + atom1, atom2 = bond.split("-") + if resname not in residue_bonds: + residue_bonds[resname] = [] + residue_bonds[resname].append(Bond(atom1, atom2, float(length), float(stddev))) + residue_bonds["UNK"] = [] + + # Load bond angles. + residue_bond_angles = {} + next(lines_iter) # Skip empty line. + next(lines_iter) # Skip header line. + for line in lines_iter: + if line.strip() == "-": + break + bond, resname, angle_degree, stddev_degree = line.split() + atom1, atom2, atom3 = bond.split("-") + if resname not in residue_bond_angles: + residue_bond_angles[resname] = [] + residue_bond_angles[resname].append( + BondAngle( + atom1, + atom2, + atom3, + float(angle_degree) / 180.0 * np.pi, + float(stddev_degree) / 180.0 * np.pi, + ) + ) + residue_bond_angles["UNK"] = [] + + def make_bond_key(atom1_name, atom2_name): + """Unique key to lookup bonds.""" + return "-".join(sorted([atom1_name, atom2_name])) + + # Translate bond angles into distances ("virtual bonds"). + residue_virtual_bonds = {} + for resname, bond_angles in residue_bond_angles.items(): + # Create a fast lookup dict for bond lengths. + bond_cache = {} + for b in residue_bonds[resname]: + bond_cache[make_bond_key(b.atom1_name, b.atom2_name)] = b + residue_virtual_bonds[resname] = [] + for ba in bond_angles: + bond1 = bond_cache[make_bond_key(ba.atom1_name, ba.atom2_name)] + bond2 = bond_cache[make_bond_key(ba.atom2_name, ba.atom3name)] + + # Compute distance between atom1 and atom3 using the law of cosines + # c^2 = a^2 + b^2 - 2ab*cos(gamma). + gamma = ba.angle_rad + length = np.sqrt(bond1.length**2 + bond2.length**2 - 2 * bond1.length * bond2.length * np.cos(gamma)) + + # Propagation of uncertainty assuming uncorrelated errors. + dl_outer = 0.5 / length + dl_dgamma = (2 * bond1.length * bond2.length * np.sin(gamma)) * dl_outer + dl_db1 = (2 * bond1.length - 2 * bond2.length * np.cos(gamma)) * dl_outer + dl_db2 = (2 * bond2.length - 2 * bond1.length * np.cos(gamma)) * dl_outer + stddev = np.sqrt( + (dl_dgamma * ba.stddev) ** 2 + (dl_db1 * bond1.stddev) ** 2 + (dl_db2 * bond2.stddev) ** 2 + ) + residue_virtual_bonds[resname].append(Bond(ba.atom1_name, ba.atom3name, length, stddev)) + + return (residue_bonds, residue_virtual_bonds, residue_bond_angles) + + +# Between-residue bond lengths for general bonds (first element) and for Proline +# (second element). +between_res_bond_length_c_n = [1.329, 1.341] +between_res_bond_length_stddev_c_n = [0.014, 0.016] + +# Between-residue cos_angles. +between_res_cos_angles_c_n_ca = [-0.5203, 0.0353] # degrees: 121.352 +- 2.315 +between_res_cos_angles_ca_c_n = [-0.4473, 0.0311] # degrees: 116.568 +- 1.995 + +# This mapping is used when we need to store atom data in a format that requires +# fixed atom data size for every residue (e.g. a numpy array). +atom_types = [ + "N", + "CA", + "C", + "CB", + "O", + "CG", + "CG1", + "CG2", + "OG", + "OG1", + "SG", + "CD", + "CD1", + "CD2", + "ND1", + "ND2", + "OD1", + "OD2", + "SD", + "CE", + "CE1", + "CE2", + "CE3", + "NE", + "NE1", + "NE2", + "OE1", + "OE2", + "CH2", + "NH1", + "NH2", + "OH", + "CZ", + "CZ2", + "CZ3", + "NZ", + "OXT", +] +atom_order = {atom_type: i for i, atom_type in enumerate(atom_types)} +atom_type_num = len(atom_types) # := 37. + +# A compact atom encoding with 14 columns +# pylint: disable=line-too-long +# pylint: disable=bad-whitespace +restype_name_to_atom14_names = { + "ALA": ["N", "CA", "C", "O", "CB", "", "", "", "", "", "", "", "", ""], + "ARG": [ + "N", + "CA", + "C", + "O", + "CB", + "CG", + "CD", + "NE", + "CZ", + "NH1", + "NH2", + "", + "", + "", + ], + "ASN": [ + "N", + "CA", + "C", + "O", + "CB", + "CG", + "OD1", + "ND2", + "", + "", + "", + "", + "", + "", + ], + "ASP": [ + "N", + "CA", + "C", + "O", + "CB", + "CG", + "OD1", + "OD2", + "", + "", + "", + "", + "", + "", + ], + "CYS": ["N", "CA", "C", "O", "CB", "SG", "", "", "", "", "", "", "", ""], + "GLN": [ + "N", + "CA", + "C", + "O", + "CB", + "CG", + "CD", + "OE1", + "NE2", + "", + "", + "", + "", + "", + ], + "GLU": [ + "N", + "CA", + "C", + "O", + "CB", + "CG", + "CD", + "OE1", + "OE2", + "", + "", + "", + "", + "", + ], + "GLY": ["N", "CA", "C", "O", "", "", "", "", "", "", "", "", "", ""], + "HIS": [ + "N", + "CA", + "C", + "O", + "CB", + "CG", + "ND1", + "CD2", + "CE1", + "NE2", + "", + "", + "", + "", + ], + "ILE": [ + "N", + "CA", + "C", + "O", + "CB", + "CG1", + "CG2", + "CD1", + "", + "", + "", + "", + "", + "", + ], + "LEU": [ + "N", + "CA", + "C", + "O", + "CB", + "CG", + "CD1", + "CD2", + "", + "", + "", + "", + "", + "", + ], + "LYS": [ + "N", + "CA", + "C", + "O", + "CB", + "CG", + "CD", + "CE", + "NZ", + "", + "", + "", + "", + "", + ], + "MET": [ + "N", + "CA", + "C", + "O", + "CB", + "CG", + "SD", + "CE", + "", + "", + "", + "", + "", + "", + ], + "PHE": [ + "N", + "CA", + "C", + "O", + "CB", + "CG", + "CD1", + "CD2", + "CE1", + "CE2", + "CZ", + "", + "", + "", + ], + "PRO": ["N", "CA", "C", "O", "CB", "CG", "CD", "", "", "", "", "", "", ""], + "SER": ["N", "CA", "C", "O", "CB", "OG", "", "", "", "", "", "", "", ""], + "THR": [ + "N", + "CA", + "C", + "O", + "CB", + "OG1", + "CG2", + "", + "", + "", + "", + "", + "", + "", + ], + "TRP": [ + "N", + "CA", + "C", + "O", + "CB", + "CG", + "CD1", + "CD2", + "NE1", + "CE2", + "CE3", + "CZ2", + "CZ3", + "CH2", + ], + "TYR": [ + "N", + "CA", + "C", + "O", + "CB", + "CG", + "CD1", + "CD2", + "CE1", + "CE2", + "CZ", + "OH", + "", + "", + ], + "VAL": [ + "N", + "CA", + "C", + "O", + "CB", + "CG1", + "CG2", + "", + "", + "", + "", + "", + "", + "", + ], + "UNK": ["", "", "", "", "", "", "", "", "", "", "", "", "", ""], +} +# pylint: enable=line-too-long +# pylint: enable=bad-whitespace + + +# This is the standard residue order when coding AA type as a number. +# Reproduce it by taking 3-letter AA codes and sorting them alphabetically. +restypes = [ + "A", + "R", + "N", + "D", + "C", + "Q", + "E", + "G", + "H", + "I", + "L", + "K", + "M", + "F", + "P", + "S", + "T", + "W", + "Y", + "V", +] +restype_order = {restype: i for i, restype in enumerate(restypes)} +restype_num = len(restypes) # := 20. +unk_restype_index = restype_num # Catch-all index for unknown restypes. + +restypes_with_x = restypes + ["X"] +restype_order_with_x = {restype: i for i, restype in enumerate(restypes_with_x)} + + +def sequence_to_onehot(sequence: str, mapping: Mapping[str, int], map_unknown_to_x: bool = False) -> np.ndarray: + """Maps the given sequence into a one-hot encoded matrix. + + Args: + sequence: An amino acid sequence. + mapping: A dictionary mapping amino acids to integers. + map_unknown_to_x: If True, any amino acid that is not in the mapping will be + mapped to the unknown amino acid 'X'. If the mapping doesn't contain amino acid 'X', an error will be thrown. + If False, any amino acid not in the mapping will throw an error. + + Returns: + A numpy array of shape (seq_len, num_unique_aas) with one-hot encoding of the sequence. + + Raises: + ValueError: If the mapping doesn't contain values from 0 to + num_unique_aas - 1 without any gaps. + """ + num_entries = max(mapping.values()) + 1 + + if sorted(set(mapping.values())) != list(range(num_entries)): + raise ValueError( + "The mapping must have values from 0 to num_unique_aas-1 without any gaps. Got: %s" + % sorted(mapping.values()) + ) + + one_hot_arr = np.zeros((len(sequence), num_entries), dtype=np.int32) + + for aa_index, aa_type in enumerate(sequence): + if map_unknown_to_x: + if aa_type.isalpha() and aa_type.isupper(): + aa_id = mapping.get(aa_type, mapping["X"]) + else: + raise ValueError(f"Invalid character in the sequence: {aa_type}") + else: + aa_id = mapping[aa_type] + one_hot_arr[aa_index, aa_id] = 1 + + return one_hot_arr + + +restype_1to3 = { + "A": "ALA", + "R": "ARG", + "N": "ASN", + "D": "ASP", + "C": "CYS", + "Q": "GLN", + "E": "GLU", + "G": "GLY", + "H": "HIS", + "I": "ILE", + "L": "LEU", + "K": "LYS", + "M": "MET", + "F": "PHE", + "P": "PRO", + "S": "SER", + "T": "THR", + "W": "TRP", + "Y": "TYR", + "V": "VAL", +} + + +# NB: restype_3to1 differs from Bio.PDB.protein_letters_3to1 by being a simple +# 1-to-1 mapping of 3 letter names to one letter names. The latter contains +# many more, and less common, three letter names as keys and maps many of these +# to the same one letter name (including 'X' and 'U' which we don't use here). +restype_3to1 = {v: k for k, v in restype_1to3.items()} + +# Define a restype name for all unknown residues. +unk_restype = "UNK" + +resnames = [restype_1to3[r] for r in restypes] + [unk_restype] +resname_to_idx = {resname: i for i, resname in enumerate(resnames)} + + +# The mapping here uses hhblits convention, so that B is mapped to D, J and O +# are mapped to X, U is mapped to C, and Z is mapped to E. Other than that the +# remaining 20 amino acids are kept in alphabetical order. +# There are 2 non-amino acid codes, X (representing any amino acid) and +# "-" representing a missing amino acid in an alignment. The id for these +# codes is put at the end (20 and 21) so that they can easily be ignored if +# desired. +HHBLITS_AA_TO_ID = { + "A": 0, + "B": 2, + "C": 1, + "D": 2, + "E": 3, + "F": 4, + "G": 5, + "H": 6, + "I": 7, + "J": 20, + "K": 8, + "L": 9, + "M": 10, + "N": 11, + "O": 20, + "P": 12, + "Q": 13, + "R": 14, + "S": 15, + "T": 16, + "U": 1, + "V": 17, + "W": 18, + "X": 20, + "Y": 19, + "Z": 3, + "-": 21, +} + +# Partial inversion of HHBLITS_AA_TO_ID. +ID_TO_HHBLITS_AA = { + 0: "A", + 1: "C", # Also U. + 2: "D", # Also B. + 3: "E", # Also Z. + 4: "F", + 5: "G", + 6: "H", + 7: "I", + 8: "K", + 9: "L", + 10: "M", + 11: "N", + 12: "P", + 13: "Q", + 14: "R", + 15: "S", + 16: "T", + 17: "V", + 18: "W", + 19: "Y", + 20: "X", # Includes J and O. + 21: "-", +} + +restypes_with_x_and_gap = restypes + ["X", "-"] +MAP_HHBLITS_AATYPE_TO_OUR_AATYPE = tuple( + restypes_with_x_and_gap.index(ID_TO_HHBLITS_AA[i]) for i in range(len(restypes_with_x_and_gap)) +) + + +def _make_standard_atom_mask() -> np.ndarray: + """Returns [num_res_types, num_atom_types] mask array.""" + # +1 to account for unknown (all 0s). + mask = np.zeros([restype_num + 1, atom_type_num], dtype=np.int32) + for restype, restype_letter in enumerate(restypes): + restype_name = restype_1to3[restype_letter] + atom_names = residue_atoms[restype_name] + for atom_name in atom_names: + atom_type = atom_order[atom_name] + mask[restype, atom_type] = 1 + return mask + + +STANDARD_ATOM_MASK = _make_standard_atom_mask() + + +# A one hot representation for the first and second atoms defining the axis +# of rotation for each chi-angle in each residue. +def chi_angle_atom(atom_index: int) -> np.ndarray: + """Define chi-angle rigid groups via one-hot representations.""" + chi_angles_index = {} + one_hots = [] + + for k, v in chi_angles_atoms.items(): + indices = [atom_types.index(s[atom_index]) for s in v] + indices.extend([-1] * (4 - len(indices))) + chi_angles_index[k] = indices + + for r in restypes: + res3 = restype_1to3[r] + one_hot = np.eye(atom_type_num)[chi_angles_index[res3]] + one_hots.append(one_hot) + + one_hots.append(np.zeros([4, atom_type_num])) # Add zeros for residue `X`. + one_hot = np.stack(one_hots, axis=0) + one_hot = np.transpose(one_hot, [0, 2, 1]) + + return one_hot + + +chi_atom_1_one_hot = chi_angle_atom(1) +chi_atom_2_one_hot = chi_angle_atom(2) + +# An array like chi_angles_atoms but using indices rather than names. +chi_angles_atom_indices = [chi_angles_atoms[restype_1to3[r]] for r in restypes] +chi_angles_atom_indices_ours = map_structure_with_atom_order(chi_angles_atom_indices) +chi_angles_atom_indices = np.array( + [chi_atoms + ([[0, 0, 0, 0]] * (4 - len(chi_atoms))) for chi_atoms in chi_angles_atom_indices] +) + +# Mapping from (res_name, atom_name) pairs to the atom's chi group index +# and atom index within that group. +chi_groups_for_atom = collections.defaultdict(list) +for res_name, chi_angle_atoms_for_res in chi_angles_atoms.items(): + for chi_group_i, chi_group in enumerate(chi_angle_atoms_for_res): + for atom_i, atom in enumerate(chi_group): + chi_groups_for_atom[(res_name, atom)].append((chi_group_i, atom_i)) +chi_groups_for_atom = dict(chi_groups_for_atom) + + +def _make_rigid_transformation_4x4(ex, ey, translation): + """Create a rigid 4x4 transformation matrix from two axes and transl.""" + # Normalize ex. + ex_normalized = ex / np.linalg.norm(ex) + + # make ey perpendicular to ex + ey_normalized = ey - np.dot(ey, ex_normalized) * ex_normalized + ey_normalized /= np.linalg.norm(ey_normalized) + + # compute ez as cross product + eznorm = np.cross(ex_normalized, ey_normalized) + m = np.stack([ex_normalized, ey_normalized, eznorm, translation]).transpose() + m = np.concatenate([m, [[0.0, 0.0, 0.0, 1.0]]], axis=0) + return m + + +# create an array with (restype, atomtype) --> rigid_group_idx +# and an array with (restype, atomtype, coord) for the atom positions +# and compute affine transformation matrices (4,4) from one rigid group to the +# previous group +restype_atom37_to_rigid_group = np.zeros([21, 37], dtype=int) +restype_atom37_mask = np.zeros([21, 37], dtype=np.float32) +restype_atom37_rigid_group_positions = np.zeros([21, 37, 3], dtype=np.float32) +restype_atom14_to_rigid_group = np.zeros([21, 14], dtype=int) +restype_atom14_mask = np.zeros([21, 14], dtype=np.float32) +restype_atom14_rigid_group_positions = np.zeros([21, 14, 3], dtype=np.float32) +restype_rigid_group_default_frame = np.zeros([21, 8, 4, 4], dtype=np.float32) + + +def _make_rigid_group_constants(): + """Fill the arrays above.""" + for restype, restype_letter in enumerate(restypes): + resname = restype_1to3[restype_letter] + for atomname, group_idx, atom_position in rigid_group_atom_positions[resname]: + atomtype = atom_order[atomname] + restype_atom37_to_rigid_group[restype, atomtype] = group_idx + restype_atom37_mask[restype, atomtype] = 1 + restype_atom37_rigid_group_positions[restype, atomtype, :] = atom_position + + atom14idx = restype_name_to_atom14_names[resname].index(atomname) + restype_atom14_to_rigid_group[restype, atom14idx] = group_idx + restype_atom14_mask[restype, atom14idx] = 1 + restype_atom14_rigid_group_positions[restype, atom14idx, :] = atom_position + + for restype, restype_letter in enumerate(restypes): + resname = restype_1to3[restype_letter] + atom_positions = {name: np.array(pos) for name, _, pos in rigid_group_atom_positions[resname]} + + # backbone to backbone is the identity transform + restype_rigid_group_default_frame[restype, 0, :, :] = np.eye(4) + + # pre-omega-frame to backbone (currently dummy identity matrix) + restype_rigid_group_default_frame[restype, 1, :, :] = np.eye(4) + + # phi-frame to backbone + mat = _make_rigid_transformation_4x4( + ex=atom_positions["N"] - atom_positions["CA"], + ey=np.array([1.0, 0.0, 0.0]), + translation=atom_positions["N"], + ) + restype_rigid_group_default_frame[restype, 2, :, :] = mat + + # psi-frame to backbone + mat = _make_rigid_transformation_4x4( + ex=atom_positions["C"] - atom_positions["CA"], + ey=atom_positions["CA"] - atom_positions["N"], + translation=atom_positions["C"], + ) + restype_rigid_group_default_frame[restype, 3, :, :] = mat + + # chi1-frame to backbone + if chi_angles_mask[restype][0]: + base_atom_names = chi_angles_atoms[resname][0] + base_atom_positions = [atom_positions[name] for name in base_atom_names] + mat = _make_rigid_transformation_4x4( + ex=base_atom_positions[2] - base_atom_positions[1], + ey=base_atom_positions[0] - base_atom_positions[1], + translation=base_atom_positions[2], + ) + restype_rigid_group_default_frame[restype, 4, :, :] = mat + + # chi2-frame to chi1-frame + # chi3-frame to chi2-frame + # chi4-frame to chi3-frame + # luckily all rotation axes for the next frame start at (0,0,0) of the + # previous frame + for chi_idx in range(1, 4): + if chi_angles_mask[restype][chi_idx]: + axis_end_atom_name = chi_angles_atoms[resname][chi_idx][2] + axis_end_atom_position = atom_positions[axis_end_atom_name] + mat = _make_rigid_transformation_4x4( + ex=axis_end_atom_position, + ey=np.array([-1.0, 0.0, 0.0]), + translation=axis_end_atom_position, + ) + restype_rigid_group_default_frame[restype, 4 + chi_idx, :, :] = mat + + +_make_rigid_group_constants() + + +def make_atom14_dists_bounds(overlap_tolerance=1.5, bond_length_tolerance_factor=15): + """compute upper and lower bounds for bonds to assess violations.""" + restype_atom14_bond_lower_bound = np.zeros([21, 14, 14], np.float32) + restype_atom14_bond_upper_bound = np.zeros([21, 14, 14], np.float32) + restype_atom14_bond_stddev = np.zeros([21, 14, 14], np.float32) + residue_bonds, residue_virtual_bonds, _ = load_stereo_chemical_props() + for restype, restype_letter in enumerate(restypes): + resname = restype_1to3[restype_letter] + atom_list = restype_name_to_atom14_names[resname] + + # create lower and upper bounds for clashes + for atom1_idx, atom1_name in enumerate(atom_list): + if not atom1_name: + continue + atom1_radius = van_der_waals_radius[atom1_name[0]] + for atom2_idx, atom2_name in enumerate(atom_list): + if (not atom2_name) or atom1_idx == atom2_idx: + continue + atom2_radius = van_der_waals_radius[atom2_name[0]] + lower = atom1_radius + atom2_radius - overlap_tolerance + upper = 1e10 + restype_atom14_bond_lower_bound[restype, atom1_idx, atom2_idx] = lower + restype_atom14_bond_lower_bound[restype, atom2_idx, atom1_idx] = lower + restype_atom14_bond_upper_bound[restype, atom1_idx, atom2_idx] = upper + restype_atom14_bond_upper_bound[restype, atom2_idx, atom1_idx] = upper + + # overwrite lower and upper bounds for bonds and angles + for b in residue_bonds[resname] + residue_virtual_bonds[resname]: + atom1_idx = atom_list.index(b.atom1_name) + atom2_idx = atom_list.index(b.atom2_name) + lower = b.length - bond_length_tolerance_factor * b.stddev + upper = b.length + bond_length_tolerance_factor * b.stddev + restype_atom14_bond_lower_bound[restype, atom1_idx, atom2_idx] = lower + restype_atom14_bond_lower_bound[restype, atom2_idx, atom1_idx] = lower + restype_atom14_bond_upper_bound[restype, atom1_idx, atom2_idx] = upper + restype_atom14_bond_upper_bound[restype, atom2_idx, atom1_idx] = upper + restype_atom14_bond_stddev[restype, atom1_idx, atom2_idx] = b.stddev + restype_atom14_bond_stddev[restype, atom2_idx, atom1_idx] = b.stddev + return { + "lower_bound": restype_atom14_bond_lower_bound, # shape (21,14,14) + "upper_bound": restype_atom14_bond_upper_bound, # shape (21,14,14) + "stddev": restype_atom14_bond_stddev, # shape (21,14,14) + } + + +restype_atom14_ambiguous_atoms = np.zeros((21, 14), dtype=np.float32) +restype_atom14_ambiguous_atoms_swap_idx = np.tile(np.arange(14, dtype=int), (21, 1)) + + +def _make_atom14_ambiguity_feats(): + for res, pairs in residue_atom_renaming_swaps.items(): + res_idx = restype_order[restype_3to1[res]] + for atom1, atom2 in pairs.items(): + atom1_idx = restype_name_to_atom14_names[res].index(atom1) + atom2_idx = restype_name_to_atom14_names[res].index(atom2) + restype_atom14_ambiguous_atoms[res_idx, atom1_idx] = 1 + restype_atom14_ambiguous_atoms[res_idx, atom2_idx] = 1 + restype_atom14_ambiguous_atoms_swap_idx[res_idx, atom1_idx] = atom2_idx + restype_atom14_ambiguous_atoms_swap_idx[res_idx, atom2_idx] = atom1_idx + + +_make_atom14_ambiguity_feats() + + +def aatype_to_str_sequence(aatype): + return "".join([restypes_with_x[aatype[i]] for i in range(len(aatype))]) diff --git a/src/transformers/models/esm/openfold_utils/rigid_utils.py b/src/transformers/models/esm/openfold_utils/rigid_utils.py new file mode 100644 index 00000000000000..c437cf7953f84b --- /dev/null +++ b/src/transformers/models/esm/openfold_utils/rigid_utils.py @@ -0,0 +1,1290 @@ +# Copyright 2021 AlQuraishi Laboratory +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from functools import lru_cache +from typing import Any, Callable, Optional, Sequence, Tuple + +import numpy as np +import torch + + +def rot_matmul(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: + """ + Performs matrix multiplication of two rotation matrix tensors. Written out by hand to avoid AMP downcasting. + + Args: + a: [*, 3, 3] left multiplicand + b: [*, 3, 3] right multiplicand + Returns: + The product ab + """ + + def row_mul(i): + return torch.stack( + [ + a[..., i, 0] * b[..., 0, 0] + a[..., i, 1] * b[..., 1, 0] + a[..., i, 2] * b[..., 2, 0], + a[..., i, 0] * b[..., 0, 1] + a[..., i, 1] * b[..., 1, 1] + a[..., i, 2] * b[..., 2, 1], + a[..., i, 0] * b[..., 0, 2] + a[..., i, 1] * b[..., 1, 2] + a[..., i, 2] * b[..., 2, 2], + ], + dim=-1, + ) + + return torch.stack( + [ + row_mul(0), + row_mul(1), + row_mul(2), + ], + dim=-2, + ) + + +def rot_vec_mul(r: torch.Tensor, t: torch.Tensor) -> torch.Tensor: + """ + Applies a rotation to a vector. Written out by hand to avoid transfer to avoid AMP downcasting. + + Args: + r: [*, 3, 3] rotation matrices + t: [*, 3] coordinate tensors + Returns: + [*, 3] rotated coordinates + """ + x, y, z = torch.unbind(t, dim=-1) + return torch.stack( + [ + r[..., 0, 0] * x + r[..., 0, 1] * y + r[..., 0, 2] * z, + r[..., 1, 0] * x + r[..., 1, 1] * y + r[..., 1, 2] * z, + r[..., 2, 0] * x + r[..., 2, 1] * y + r[..., 2, 2] * z, + ], + dim=-1, + ) + + +@lru_cache(maxsize=None) +def identity_rot_mats( + batch_dims: Tuple[int], + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + requires_grad: bool = True, +) -> torch.Tensor: + rots = torch.eye(3, dtype=dtype, device=device, requires_grad=requires_grad) + rots = rots.view(*((1,) * len(batch_dims)), 3, 3) + rots = rots.expand(*batch_dims, -1, -1) + rots = rots.contiguous() + + return rots + + +@lru_cache(maxsize=None) +def identity_trans( + batch_dims: Tuple[int], + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + requires_grad: bool = True, +) -> torch.Tensor: + trans = torch.zeros((*batch_dims, 3), dtype=dtype, device=device, requires_grad=requires_grad) + return trans + + +@lru_cache(maxsize=None) +def identity_quats( + batch_dims: Tuple[int], + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + requires_grad: bool = True, +) -> torch.Tensor: + quat = torch.zeros((*batch_dims, 4), dtype=dtype, device=device, requires_grad=requires_grad) + + with torch.no_grad(): + quat[..., 0] = 1 + + return quat + + +_quat_elements = ["a", "b", "c", "d"] +_qtr_keys = [l1 + l2 for l1 in _quat_elements for l2 in _quat_elements] +_qtr_ind_dict = {key: ind for ind, key in enumerate(_qtr_keys)} + + +def _to_mat(pairs): + mat = np.zeros((4, 4)) + for pair in pairs: + key, value = pair + ind = _qtr_ind_dict[key] + mat[ind // 4][ind % 4] = value + + return mat + + +_QTR_MAT = np.zeros((4, 4, 3, 3)) +_QTR_MAT[..., 0, 0] = _to_mat([("aa", 1), ("bb", 1), ("cc", -1), ("dd", -1)]) +_QTR_MAT[..., 0, 1] = _to_mat([("bc", 2), ("ad", -2)]) +_QTR_MAT[..., 0, 2] = _to_mat([("bd", 2), ("ac", 2)]) +_QTR_MAT[..., 1, 0] = _to_mat([("bc", 2), ("ad", 2)]) +_QTR_MAT[..., 1, 1] = _to_mat([("aa", 1), ("bb", -1), ("cc", 1), ("dd", -1)]) +_QTR_MAT[..., 1, 2] = _to_mat([("cd", 2), ("ab", -2)]) +_QTR_MAT[..., 2, 0] = _to_mat([("bd", 2), ("ac", -2)]) +_QTR_MAT[..., 2, 1] = _to_mat([("cd", 2), ("ab", 2)]) +_QTR_MAT[..., 2, 2] = _to_mat([("aa", 1), ("bb", -1), ("cc", -1), ("dd", 1)]) + + +def quat_to_rot(quat: torch.Tensor) -> torch.Tensor: + """ + Converts a quaternion to a rotation matrix. + + Args: + quat: [*, 4] quaternions + Returns: + [*, 3, 3] rotation matrices + """ + # [*, 4, 4] + quat = quat[..., None] * quat[..., None, :] + + # [4, 4, 3, 3] + mat = _get_quat("_QTR_MAT", dtype=quat.dtype, device=quat.device) + + # [*, 4, 4, 3, 3] + shaped_qtr_mat = mat.view((1,) * len(quat.shape[:-2]) + mat.shape) + quat = quat[..., None, None] * shaped_qtr_mat + + # [*, 3, 3] + return torch.sum(quat, dim=(-3, -4)) + + +def rot_to_quat( + rot: torch.Tensor, +): + if rot.shape[-2:] != (3, 3): + raise ValueError("Input rotation is incorrectly shaped") + + rot = [[rot[..., i, j] for j in range(3)] for i in range(3)] + [[xx, xy, xz], [yx, yy, yz], [zx, zy, zz]] = rot + + k = [ + [ + xx + yy + zz, + zy - yz, + xz - zx, + yx - xy, + ], + [ + zy - yz, + xx - yy - zz, + xy + yx, + xz + zx, + ], + [ + xz - zx, + xy + yx, + yy - xx - zz, + yz + zy, + ], + [ + yx - xy, + xz + zx, + yz + zy, + zz - xx - yy, + ], + ] + + k = (1.0 / 3.0) * torch.stack([torch.stack(t, dim=-1) for t in k], dim=-2) + + _, vectors = torch.linalg.eigh(k) + return vectors[..., -1] + + +_QUAT_MULTIPLY = np.zeros((4, 4, 4)) +_QUAT_MULTIPLY[:, :, 0] = [[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, -1]] + +_QUAT_MULTIPLY[:, :, 1] = [[0, 1, 0, 0], [1, 0, 0, 0], [0, 0, 0, 1], [0, 0, -1, 0]] + +_QUAT_MULTIPLY[:, :, 2] = [[0, 0, 1, 0], [0, 0, 0, -1], [1, 0, 0, 0], [0, 1, 0, 0]] + +_QUAT_MULTIPLY[:, :, 3] = [[0, 0, 0, 1], [0, 0, 1, 0], [0, -1, 0, 0], [1, 0, 0, 0]] + +_QUAT_MULTIPLY_BY_VEC = _QUAT_MULTIPLY[:, 1:, :] + +_CACHED_QUATS = { + "_QTR_MAT": _QTR_MAT, + "_QUAT_MULTIPLY": _QUAT_MULTIPLY, + "_QUAT_MULTIPLY_BY_VEC": _QUAT_MULTIPLY_BY_VEC, +} + + +@lru_cache(maxsize=None) +def _get_quat(quat_key, dtype, device): + return torch.tensor(_CACHED_QUATS[quat_key], dtype=dtype, device=device) + + +def quat_multiply(quat1, quat2): + """Multiply a quaternion by another quaternion.""" + mat = _get_quat("_QUAT_MULTIPLY", dtype=quat1.dtype, device=quat1.device) + reshaped_mat = mat.view((1,) * len(quat1.shape[:-1]) + mat.shape) + return torch.sum(reshaped_mat * quat1[..., :, None, None] * quat2[..., None, :, None], dim=(-3, -2)) + + +def quat_multiply_by_vec(quat, vec): + """Multiply a quaternion by a pure-vector quaternion.""" + mat = _get_quat("_QUAT_MULTIPLY_BY_VEC", dtype=quat.dtype, device=quat.device) + reshaped_mat = mat.view((1,) * len(quat.shape[:-1]) + mat.shape) + return torch.sum(reshaped_mat * quat[..., :, None, None] * vec[..., None, :, None], dim=(-3, -2)) + + +def invert_rot_mat(rot_mat: torch.Tensor): + return rot_mat.transpose(-1, -2) + + +def invert_quat(quat: torch.Tensor): + quat_prime = quat.clone() + quat_prime[..., 1:] *= -1 + inv = quat_prime / torch.sum(quat**2, dim=-1, keepdim=True) + return inv + + +class Rotation: + """ + A 3D rotation. Depending on how the object is initialized, the rotation is represented by either a rotation matrix + or a quaternion, though both formats are made available by helper functions. To simplify gradient computation, the + underlying format of the rotation cannot be changed in-place. Like Rigid, the class is designed to mimic the + behavior of a torch Tensor, almost as if each Rotation object were a tensor of rotations, in one format or another. + """ + + def __init__( + self, + rot_mats: Optional[torch.Tensor] = None, + quats: Optional[torch.Tensor] = None, + normalize_quats: bool = True, + ): + """ + Args: + rot_mats: + A [*, 3, 3] rotation matrix tensor. Mutually exclusive with quats + quats: + A [*, 4] quaternion. Mutually exclusive with rot_mats. If normalize_quats is not True, must be a unit + quaternion + normalize_quats: + If quats is specified, whether to normalize quats + """ + if (rot_mats is None and quats is None) or (rot_mats is not None and quats is not None): + raise ValueError("Exactly one input argument must be specified") + + if (rot_mats is not None and rot_mats.shape[-2:] != (3, 3)) or (quats is not None and quats.shape[-1] != 4): + raise ValueError("Incorrectly shaped rotation matrix or quaternion") + + # Force full-precision + if quats is not None: + quats = quats.to(dtype=torch.float32) + if rot_mats is not None: + rot_mats = rot_mats.to(dtype=torch.float32) + + if quats is not None and normalize_quats: + quats = quats / torch.linalg.norm(quats, dim=-1, keepdim=True) + + self._rot_mats = rot_mats + self._quats = quats + + @staticmethod + def identity( + shape, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + requires_grad: bool = True, + fmt: str = "quat", + ) -> Rotation: + """ + Returns an identity Rotation. + + Args: + shape: + The "shape" of the resulting Rotation object. See documentation for the shape property + dtype: + The torch dtype for the rotation + device: + The torch device for the new rotation + requires_grad: + Whether the underlying tensors in the new rotation object should require gradient computation + fmt: + One of "quat" or "rot_mat". Determines the underlying format of the new object's rotation + Returns: + A new identity rotation + """ + if fmt == "rot_mat": + rot_mats = identity_rot_mats( + shape, + dtype, + device, + requires_grad, + ) + return Rotation(rot_mats=rot_mats, quats=None) + elif fmt == "quat": + quats = identity_quats(shape, dtype, device, requires_grad) + return Rotation(rot_mats=None, quats=quats, normalize_quats=False) + else: + raise ValueError(f"Invalid format: f{fmt}") + + # Magic methods + + def __getitem__(self, index: Any) -> Rotation: + """ + Allows torch-style indexing over the virtual shape of the rotation object. See documentation for the shape + property. + + Args: + index: + A torch index. E.g. (1, 3, 2), or (slice(None,)) + Returns: + The indexed rotation + """ + if type(index) != tuple: + index = (index,) + + if self._rot_mats is not None: + rot_mats = self._rot_mats[index + (slice(None), slice(None))] + return Rotation(rot_mats=rot_mats) + elif self._quats is not None: + quats = self._quats[index + (slice(None),)] + return Rotation(quats=quats, normalize_quats=False) + else: + raise ValueError("Both rotations are None") + + def __mul__( + self, + right: torch.Tensor, + ) -> Rotation: + """ + Pointwise left multiplication of the rotation with a tensor. Can be used to e.g. mask the Rotation. + + Args: + right: + The tensor multiplicand + Returns: + The product + """ + if not (isinstance(right, torch.Tensor)): + raise TypeError("The other multiplicand must be a Tensor") + + if self._rot_mats is not None: + rot_mats = self._rot_mats * right[..., None, None] + return Rotation(rot_mats=rot_mats, quats=None) + elif self._quats is not None: + quats = self._quats * right[..., None] + return Rotation(rot_mats=None, quats=quats, normalize_quats=False) + else: + raise ValueError("Both rotations are None") + + def __rmul__( + self, + left: torch.Tensor, + ) -> Rotation: + """ + Reverse pointwise multiplication of the rotation with a tensor. + + Args: + left: + The left multiplicand + Returns: + The product + """ + return self.__mul__(left) + + # Properties + + @property + def shape(self) -> torch.Size: + """ + Returns the virtual shape of the rotation object. This shape is defined as the batch dimensions of the + underlying rotation matrix or quaternion. If the Rotation was initialized with a [10, 3, 3] rotation matrix + tensor, for example, the resulting shape would be [10]. + + Returns: + The virtual shape of the rotation object + """ + s = None + if self._quats is not None: + s = self._quats.shape[:-1] + else: + s = self._rot_mats.shape[:-2] + + return s + + @property + def dtype(self) -> torch.dtype: + """ + Returns the dtype of the underlying rotation. + + Returns: + The dtype of the underlying rotation + """ + if self._rot_mats is not None: + return self._rot_mats.dtype + elif self._quats is not None: + return self._quats.dtype + else: + raise ValueError("Both rotations are None") + + @property + def device(self) -> torch.device: + """ + The device of the underlying rotation + + Returns: + The device of the underlying rotation + """ + if self._rot_mats is not None: + return self._rot_mats.device + elif self._quats is not None: + return self._quats.device + else: + raise ValueError("Both rotations are None") + + @property + def requires_grad(self) -> bool: + """ + Returns the requires_grad property of the underlying rotation + + Returns: + The requires_grad property of the underlying tensor + """ + if self._rot_mats is not None: + return self._rot_mats.requires_grad + elif self._quats is not None: + return self._quats.requires_grad + else: + raise ValueError("Both rotations are None") + + def get_rot_mats(self) -> torch.Tensor: + """ + Returns the underlying rotation as a rotation matrix tensor. + + Returns: + The rotation as a rotation matrix tensor + """ + rot_mats = self._rot_mats + if rot_mats is None: + if self._quats is None: + raise ValueError("Both rotations are None") + else: + rot_mats = quat_to_rot(self._quats) + + return rot_mats + + def get_quats(self) -> torch.Tensor: + """ + Returns the underlying rotation as a quaternion tensor. + + Depending on whether the Rotation was initialized with a quaternion, this function may call torch.linalg.eigh. + + Returns: + The rotation as a quaternion tensor. + """ + quats = self._quats + if quats is None: + if self._rot_mats is None: + raise ValueError("Both rotations are None") + else: + quats = rot_to_quat(self._rot_mats) + + return quats + + def get_cur_rot(self) -> torch.Tensor: + """ + Return the underlying rotation in its current form + + Returns: + The stored rotation + """ + if self._rot_mats is not None: + return self._rot_mats + elif self._quats is not None: + return self._quats + else: + raise ValueError("Both rotations are None") + + # Rotation functions + + def compose_q_update_vec(self, q_update_vec: torch.Tensor, normalize_quats: bool = True) -> Rotation: + """ + Returns a new quaternion Rotation after updating the current object's underlying rotation with a quaternion + update, formatted as a [*, 3] tensor whose final three columns represent x, y, z such that (1, x, y, z) is the + desired (not necessarily unit) quaternion update. + + Args: + q_update_vec: + A [*, 3] quaternion update tensor + normalize_quats: + Whether to normalize the output quaternion + Returns: + An updated Rotation + """ + quats = self.get_quats() + new_quats = quats + quat_multiply_by_vec(quats, q_update_vec) + return Rotation( + rot_mats=None, + quats=new_quats, + normalize_quats=normalize_quats, + ) + + def compose_r(self, r: Rotation) -> Rotation: + """ + Compose the rotation matrices of the current Rotation object with those of another. + + Args: + r: + An update rotation object + Returns: + An updated rotation object + """ + r1 = self.get_rot_mats() + r2 = r.get_rot_mats() + new_rot_mats = rot_matmul(r1, r2) + return Rotation(rot_mats=new_rot_mats, quats=None) + + def compose_q(self, r: Rotation, normalize_quats: bool = True) -> Rotation: + """ + Compose the quaternions of the current Rotation object with those of another. + + Depending on whether either Rotation was initialized with quaternions, this function may call + torch.linalg.eigh. + + Args: + r: + An update rotation object + Returns: + An updated rotation object + """ + q1 = self.get_quats() + q2 = r.get_quats() + new_quats = quat_multiply(q1, q2) + return Rotation(rot_mats=None, quats=new_quats, normalize_quats=normalize_quats) + + def apply(self, pts: torch.Tensor) -> torch.Tensor: + """ + Apply the current Rotation as a rotation matrix to a set of 3D coordinates. + + Args: + pts: + A [*, 3] set of points + Returns: + [*, 3] rotated points + """ + rot_mats = self.get_rot_mats() + return rot_vec_mul(rot_mats, pts) + + def invert_apply(self, pts: torch.Tensor) -> torch.Tensor: + """ + The inverse of the apply() method. + + Args: + pts: + A [*, 3] set of points + Returns: + [*, 3] inverse-rotated points + """ + rot_mats = self.get_rot_mats() + inv_rot_mats = invert_rot_mat(rot_mats) + return rot_vec_mul(inv_rot_mats, pts) + + def invert(self) -> Rotation: + """ + Returns the inverse of the current Rotation. + + Returns: + The inverse of the current Rotation + """ + if self._rot_mats is not None: + return Rotation(rot_mats=invert_rot_mat(self._rot_mats), quats=None) + elif self._quats is not None: + return Rotation( + rot_mats=None, + quats=invert_quat(self._quats), + normalize_quats=False, + ) + else: + raise ValueError("Both rotations are None") + + # "Tensor" stuff + + def unsqueeze( + self, + dim: int, + ) -> Rigid: + """ + Analogous to torch.unsqueeze. The dimension is relative to the shape of the Rotation object. + + Args: + dim: A positive or negative dimension index. + Returns: + The unsqueezed Rotation. + """ + if dim >= len(self.shape): + raise ValueError("Invalid dimension") + + if self._rot_mats is not None: + rot_mats = self._rot_mats.unsqueeze(dim if dim >= 0 else dim - 2) + return Rotation(rot_mats=rot_mats, quats=None) + elif self._quats is not None: + quats = self._quats.unsqueeze(dim if dim >= 0 else dim - 1) + return Rotation(rot_mats=None, quats=quats, normalize_quats=False) + else: + raise ValueError("Both rotations are None") + + @staticmethod + def cat( + rs: Sequence[Rotation], + dim: int, + ) -> Rigid: + """ + Concatenates rotations along one of the batch dimensions. Analogous to torch.cat(). + + Note that the output of this operation is always a rotation matrix, regardless of the format of input + rotations. + + Args: + rs: + A list of rotation objects + dim: + The dimension along which the rotations should be concatenated + Returns: + A concatenated Rotation object in rotation matrix format + """ + rot_mats = [r.get_rot_mats() for r in rs] + rot_mats = torch.cat(rot_mats, dim=dim if dim >= 0 else dim - 2) + + return Rotation(rot_mats=rot_mats, quats=None) + + def map_tensor_fn(self, fn: Callable[torch.Tensor, torch.Tensor]) -> Rotation: + """ + Apply a Tensor -> Tensor function to underlying rotation tensors, mapping over the rotation dimension(s). Can + be used e.g. to sum out a one-hot batch dimension. + + Args: + fn: + A Tensor -> Tensor function to be mapped over the Rotation + Returns: + The transformed Rotation object + """ + if self._rot_mats is not None: + rot_mats = self._rot_mats.view(self._rot_mats.shape[:-2] + (9,)) + rot_mats = torch.stack(list(map(fn, torch.unbind(rot_mats, dim=-1))), dim=-1) + rot_mats = rot_mats.view(rot_mats.shape[:-1] + (3, 3)) + return Rotation(rot_mats=rot_mats, quats=None) + elif self._quats is not None: + quats = torch.stack(list(map(fn, torch.unbind(self._quats, dim=-1))), dim=-1) + return Rotation(rot_mats=None, quats=quats, normalize_quats=False) + else: + raise ValueError("Both rotations are None") + + def cuda(self) -> Rotation: + """ + Analogous to the cuda() method of torch Tensors + + Returns: + A copy of the Rotation in CUDA memory + """ + if self._rot_mats is not None: + return Rotation(rot_mats=self._rot_mats.cuda(), quats=None) + elif self._quats is not None: + return Rotation(rot_mats=None, quats=self._quats.cuda(), normalize_quats=False) + else: + raise ValueError("Both rotations are None") + + def to(self, device: Optional[torch.device], dtype: Optional[torch.dtype]) -> Rotation: + """ + Analogous to the to() method of torch Tensors + + Args: + device: + A torch device + dtype: + A torch dtype + Returns: + A copy of the Rotation using the new device and dtype + """ + if self._rot_mats is not None: + return Rotation( + rot_mats=self._rot_mats.to(device=device, dtype=dtype), + quats=None, + ) + elif self._quats is not None: + return Rotation( + rot_mats=None, + quats=self._quats.to(device=device, dtype=dtype), + normalize_quats=False, + ) + else: + raise ValueError("Both rotations are None") + + def detach(self) -> Rotation: + """ + Returns a copy of the Rotation whose underlying Tensor has been detached from its torch graph. + + Returns: + A copy of the Rotation whose underlying Tensor has been detached from its torch graph + """ + if self._rot_mats is not None: + return Rotation(rot_mats=self._rot_mats.detach(), quats=None) + elif self._quats is not None: + return Rotation( + rot_mats=None, + quats=self._quats.detach(), + normalize_quats=False, + ) + else: + raise ValueError("Both rotations are None") + + +class Rigid: + """ + A class representing a rigid transformation. Little more than a wrapper around two objects: a Rotation object and a + [*, 3] translation Designed to behave approximately like a single torch tensor with the shape of the shared batch + dimensions of its component parts. + """ + + def __init__( + self, + rots: Optional[Rotation], + trans: Optional[torch.Tensor], + ): + """ + Args: + rots: A [*, 3, 3] rotation tensor + trans: A corresponding [*, 3] translation tensor + """ + # (we need device, dtype, etc. from at least one input) + + batch_dims, dtype, device, requires_grad = None, None, None, None + if trans is not None: + batch_dims = trans.shape[:-1] + dtype = trans.dtype + device = trans.device + requires_grad = trans.requires_grad + elif rots is not None: + batch_dims = rots.shape + dtype = rots.dtype + device = rots.device + requires_grad = rots.requires_grad + else: + raise ValueError("At least one input argument must be specified") + + if rots is None: + rots = Rotation.identity( + batch_dims, + dtype, + device, + requires_grad, + ) + elif trans is None: + trans = identity_trans( + batch_dims, + dtype, + device, + requires_grad, + ) + + if (rots.shape != trans.shape[:-1]) or (rots.device != trans.device): + raise ValueError("Rots and trans incompatible") + + # Force full precision. Happens to the rotations automatically. + trans = trans.to(dtype=torch.float32) + + self._rots = rots + self._trans = trans + + @staticmethod + def identity( + shape: Tuple[int], + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + requires_grad: bool = True, + fmt: str = "quat", + ) -> Rigid: + """ + Constructs an identity transformation. + + Args: + shape: + The desired shape + dtype: + The dtype of both internal tensors + device: + The device of both internal tensors + requires_grad: + Whether grad should be enabled for the internal tensors + Returns: + The identity transformation + """ + return Rigid( + Rotation.identity(shape, dtype, device, requires_grad, fmt=fmt), + identity_trans(shape, dtype, device, requires_grad), + ) + + def __getitem__( + self, + index: Any, + ) -> Rigid: + """ + Indexes the affine transformation with PyTorch-style indices. The index is applied to the shared dimensions of + both the rotation and the translation. + + E.g.:: + + r = Rotation(rot_mats=torch.rand(10, 10, 3, 3), quats=None) t = Rigid(r, torch.rand(10, 10, 3)) indexed = + t[3, 4:6] assert(indexed.shape == (2,)) assert(indexed.get_rots().shape == (2,)) + assert(indexed.get_trans().shape == (2, 3)) + + Args: + index: A standard torch tensor index. E.g. 8, (10, None, 3), + or (3, slice(0, 1, None)) + Returns: + The indexed tensor + """ + if type(index) != tuple: + index = (index,) + + return Rigid( + self._rots[index], + self._trans[index + (slice(None),)], + ) + + def __mul__( + self, + right: torch.Tensor, + ) -> Rigid: + """ + Pointwise left multiplication of the transformation with a tensor. Can be used to e.g. mask the Rigid. + + Args: + right: + The tensor multiplicand + Returns: + The product + """ + if not (isinstance(right, torch.Tensor)): + raise TypeError("The other multiplicand must be a Tensor") + + new_rots = self._rots * right + new_trans = self._trans * right[..., None] + + return Rigid(new_rots, new_trans) + + def __rmul__( + self, + left: torch.Tensor, + ) -> Rigid: + """ + Reverse pointwise multiplication of the transformation with a tensor. + + Args: + left: + The left multiplicand + Returns: + The product + """ + return self.__mul__(left) + + @property + def shape(self) -> torch.Size: + """ + Returns the shape of the shared dimensions of the rotation and the translation. + + Returns: + The shape of the transformation + """ + s = self._trans.shape[:-1] + return s + + @property + def device(self) -> torch.device: + """ + Returns the device on which the Rigid's tensors are located. + + Returns: + The device on which the Rigid's tensors are located + """ + return self._trans.device + + def get_rots(self) -> Rotation: + """ + Getter for the rotation. + + Returns: + The rotation object + """ + return self._rots + + def get_trans(self) -> torch.Tensor: + """ + Getter for the translation. + + Returns: + The stored translation + """ + return self._trans + + def compose_q_update_vec( + self, + q_update_vec: torch.Tensor, + ) -> Rigid: + """ + Composes the transformation with a quaternion update vector of shape [*, 6], where the final 6 columns + represent the x, y, and z values of a quaternion of form (1, x, y, z) followed by a 3D translation. + + Args: + q_vec: The quaternion update vector. + Returns: + The composed transformation. + """ + q_vec, t_vec = q_update_vec[..., :3], q_update_vec[..., 3:] + new_rots = self._rots.compose_q_update_vec(q_vec) + + trans_update = self._rots.apply(t_vec) + new_translation = self._trans + trans_update + + return Rigid(new_rots, new_translation) + + def compose( + self, + r: Rigid, + ) -> Rigid: + """ + Composes the current rigid object with another. + + Args: + r: + Another Rigid object + Returns: + The composition of the two transformations + """ + new_rot = self._rots.compose_r(r._rots) + new_trans = self._rots.apply(r._trans) + self._trans + return Rigid(new_rot, new_trans) + + def apply( + self, + pts: torch.Tensor, + ) -> torch.Tensor: + """ + Applies the transformation to a coordinate tensor. + + Args: + pts: A [*, 3] coordinate tensor. + Returns: + The transformed points. + """ + rotated = self._rots.apply(pts) + return rotated + self._trans + + def invert_apply(self, pts: torch.Tensor) -> torch.Tensor: + """ + Applies the inverse of the transformation to a coordinate tensor. + + Args: + pts: A [*, 3] coordinate tensor + Returns: + The transformed points. + """ + pts = pts - self._trans + return self._rots.invert_apply(pts) + + def invert(self) -> Rigid: + """ + Inverts the transformation. + + Returns: + The inverse transformation. + """ + rot_inv = self._rots.invert() + trn_inv = rot_inv.apply(self._trans) + + return Rigid(rot_inv, -1 * trn_inv) + + def map_tensor_fn(self, fn: Callable[torch.Tensor, torch.Tensor]) -> Rigid: + """ + Apply a Tensor -> Tensor function to underlying translation and rotation tensors, mapping over the + translation/rotation dimensions respectively. + + Args: + fn: + A Tensor -> Tensor function to be mapped over the Rigid + Returns: + The transformed Rigid object + """ + new_rots = self._rots.map_tensor_fn(fn) + new_trans = torch.stack(list(map(fn, torch.unbind(self._trans, dim=-1))), dim=-1) + + return Rigid(new_rots, new_trans) + + def to_tensor_4x4(self) -> torch.Tensor: + """ + Converts a transformation to a homogenous transformation tensor. + + Returns: + A [*, 4, 4] homogenous transformation tensor + """ + tensor = self._trans.new_zeros((*self.shape, 4, 4)) + tensor[..., :3, :3] = self._rots.get_rot_mats() + tensor[..., :3, 3] = self._trans + tensor[..., 3, 3] = 1 + return tensor + + @staticmethod + def from_tensor_4x4(t: torch.Tensor) -> Rigid: + """ + Constructs a transformation from a homogenous transformation tensor. + + Args: + t: [*, 4, 4] homogenous transformation tensor + Returns: + T object with shape [*] + """ + if t.shape[-2:] != (4, 4): + raise ValueError("Incorrectly shaped input tensor") + + rots = Rotation(rot_mats=t[..., :3, :3], quats=None) + trans = t[..., :3, 3] + + return Rigid(rots, trans) + + def to_tensor_7(self) -> torch.Tensor: + """ + Converts a transformation to a tensor with 7 final columns, four for the quaternion followed by three for the + translation. + + Returns: + A [*, 7] tensor representation of the transformation + """ + tensor = self._trans.new_zeros((*self.shape, 7)) + tensor[..., :4] = self._rots.get_quats() + tensor[..., 4:] = self._trans + + return tensor + + @staticmethod + def from_tensor_7( + t: torch.Tensor, + normalize_quats: bool = False, + ) -> Rigid: + if t.shape[-1] != 7: + raise ValueError("Incorrectly shaped input tensor") + + quats, trans = t[..., :4], t[..., 4:] + + rots = Rotation(rot_mats=None, quats=quats, normalize_quats=normalize_quats) + + return Rigid(rots, trans) + + @staticmethod + def from_3_points( + p_neg_x_axis: torch.Tensor, origin: torch.Tensor, p_xy_plane: torch.Tensor, eps: float = 1e-8 + ) -> Rigid: + """ + Implements algorithm 21. Constructs transformations from sets of 3 points using the Gram-Schmidt algorithm. + + Args: + p_neg_x_axis: [*, 3] coordinates + origin: [*, 3] coordinates used as frame origins + p_xy_plane: [*, 3] coordinates + eps: Small epsilon value + Returns: + A transformation object of shape [*] + """ + p_neg_x_axis = torch.unbind(p_neg_x_axis, dim=-1) + origin = torch.unbind(origin, dim=-1) + p_xy_plane = torch.unbind(p_xy_plane, dim=-1) + + e0 = [c1 - c2 for c1, c2 in zip(origin, p_neg_x_axis)] + e1 = [c1 - c2 for c1, c2 in zip(p_xy_plane, origin)] + + denom = torch.sqrt(sum((c * c for c in e0)) + eps) + e0 = [c / denom for c in e0] + dot = sum((c1 * c2 for c1, c2 in zip(e0, e1))) + e1 = [c2 - c1 * dot for c1, c2 in zip(e0, e1)] + denom = torch.sqrt(sum((c * c for c in e1)) + eps) + e1 = [c / denom for c in e1] + e2 = [ + e0[1] * e1[2] - e0[2] * e1[1], + e0[2] * e1[0] - e0[0] * e1[2], + e0[0] * e1[1] - e0[1] * e1[0], + ] + + rots = torch.stack([c for tup in zip(e0, e1, e2) for c in tup], dim=-1) + rots = rots.reshape(rots.shape[:-1] + (3, 3)) + + rot_obj = Rotation(rot_mats=rots, quats=None) + + return Rigid(rot_obj, torch.stack(origin, dim=-1)) + + def unsqueeze( + self, + dim: int, + ) -> Rigid: + """ + Analogous to torch.unsqueeze. The dimension is relative to the shared dimensions of the rotation/translation. + + Args: + dim: A positive or negative dimension index. + Returns: + The unsqueezed transformation. + """ + if dim >= len(self.shape): + raise ValueError("Invalid dimension") + rots = self._rots.unsqueeze(dim) + trans = self._trans.unsqueeze(dim if dim >= 0 else dim - 1) + + return Rigid(rots, trans) + + @staticmethod + def cat( + ts: Sequence[Rigid], + dim: int, + ) -> Rigid: + """ + Concatenates transformations along a new dimension. + + Args: + ts: + A list of T objects + dim: + The dimension along which the transformations should be concatenated + Returns: + A concatenated transformation object + """ + rots = Rotation.cat([t._rots for t in ts], dim) + trans = torch.cat([t._trans for t in ts], dim=dim if dim >= 0 else dim - 1) + + return Rigid(rots, trans) + + def apply_rot_fn(self, fn: Callable[Rotation, Rotation]) -> Rigid: + """ + Applies a Rotation -> Rotation function to the stored rotation object. + + Args: + fn: A function of type Rotation -> Rotation + Returns: + A transformation object with a transformed rotation. + """ + return Rigid(fn(self._rots), self._trans) + + def apply_trans_fn(self, fn: Callable[torch.Tensor, torch.Tensor]) -> Rigid: + """ + Applies a Tensor -> Tensor function to the stored translation. + + Args: + fn: + A function of type Tensor -> Tensor to be applied to the translation + Returns: + A transformation object with a transformed translation. + """ + return Rigid(self._rots, fn(self._trans)) + + def scale_translation(self, trans_scale_factor: float) -> Rigid: + """ + Scales the translation by a constant factor. + + Args: + trans_scale_factor: + The constant factor + Returns: + A transformation object with a scaled translation. + """ + return self.apply_trans_fn(lambda t: t * trans_scale_factor) + + def stop_rot_gradient(self) -> Rigid: + """ + Detaches the underlying rotation object + + Returns: + A transformation object with detached rotations + """ + return self.apply_rot_fn(lambda r: r.detach()) + + @staticmethod + def make_transform_from_reference(n_xyz, ca_xyz, c_xyz, eps=1e-20): + """ + Returns a transformation object from reference coordinates. + + Note that this method does not take care of symmetries. If you provide the atom positions in the non-standard + way, the N atom will end up not at [-0.527250, 1.359329, 0.0] but instead at [-0.527250, -1.359329, 0.0]. You + need to take care of such cases in your code. + + Args: + n_xyz: A [*, 3] tensor of nitrogen xyz coordinates. + ca_xyz: A [*, 3] tensor of carbon alpha xyz coordinates. + c_xyz: A [*, 3] tensor of carbon xyz coordinates. + Returns: + A transformation object. After applying the translation and rotation to the reference backbone, the + coordinates will approximately equal to the input coordinates. + """ + translation = -1 * ca_xyz + n_xyz = n_xyz + translation + c_xyz = c_xyz + translation + + c_x, c_y, c_z = [c_xyz[..., i] for i in range(3)] + norm = torch.sqrt(eps + c_x**2 + c_y**2) + sin_c1 = -c_y / norm + cos_c1 = c_x / norm + + c1_rots = sin_c1.new_zeros((*sin_c1.shape, 3, 3)) + c1_rots[..., 0, 0] = cos_c1 + c1_rots[..., 0, 1] = -1 * sin_c1 + c1_rots[..., 1, 0] = sin_c1 + c1_rots[..., 1, 1] = cos_c1 + c1_rots[..., 2, 2] = 1 + + norm = torch.sqrt(eps + c_x**2 + c_y**2 + c_z**2) + sin_c2 = c_z / norm + cos_c2 = torch.sqrt(c_x**2 + c_y**2) / norm + + c2_rots = sin_c2.new_zeros((*sin_c2.shape, 3, 3)) + c2_rots[..., 0, 0] = cos_c2 + c2_rots[..., 0, 2] = sin_c2 + c2_rots[..., 1, 1] = 1 + c2_rots[..., 2, 0] = -1 * sin_c2 + c2_rots[..., 2, 2] = cos_c2 + + c_rots = rot_matmul(c2_rots, c1_rots) + n_xyz = rot_vec_mul(c_rots, n_xyz) + + _, n_y, n_z = [n_xyz[..., i] for i in range(3)] + norm = torch.sqrt(eps + n_y**2 + n_z**2) + sin_n = -n_z / norm + cos_n = n_y / norm + + n_rots = sin_c2.new_zeros((*sin_c2.shape, 3, 3)) + n_rots[..., 0, 0] = 1 + n_rots[..., 1, 1] = cos_n + n_rots[..., 1, 2] = -1 * sin_n + n_rots[..., 2, 1] = sin_n + n_rots[..., 2, 2] = cos_n + + rots = rot_matmul(n_rots, c_rots) + + rots = rots.transpose(-1, -2) + translation = -1 * translation + + rot_obj = Rotation(rot_mats=rots, quats=None) + + return Rigid(rot_obj, translation) + + def cuda(self) -> Rigid: + """ + Moves the transformation object to GPU memory + + Returns: + A version of the transformation on GPU + """ + return Rigid(self._rots.cuda(), self._trans.cuda()) diff --git a/src/transformers/models/esm/openfold_utils/tensor_utils.py b/src/transformers/models/esm/openfold_utils/tensor_utils.py new file mode 100644 index 00000000000000..60e8b3f2146614 --- /dev/null +++ b/src/transformers/models/esm/openfold_utils/tensor_utils.py @@ -0,0 +1,116 @@ +# Copyright 2021 AlQuraishi Laboratory +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from functools import partial +from typing import List + +import torch +import torch.nn as nn + + +def add(m1, m2, inplace): + # The first operation in a checkpoint can't be in-place, but it's + # nice to have in-place addition during inference. Thus... + if not inplace: + m1 = m1 + m2 + else: + m1 += m2 + + return m1 + + +def permute_final_dims(tensor: torch.Tensor, inds: List[int]): + zero_index = -1 * len(inds) + first_inds = list(range(len(tensor.shape[:zero_index]))) + return tensor.permute(first_inds + [zero_index + i for i in inds]) + + +def flatten_final_dims(t: torch.Tensor, no_dims: int): + return t.reshape(t.shape[:-no_dims] + (-1,)) + + +def masked_mean(mask, value, dim, eps=1e-4): + mask = mask.expand(*value.shape) + return torch.sum(mask * value, dim=dim) / (eps + torch.sum(mask, dim=dim)) + + +def pts_to_distogram(pts, min_bin=2.3125, max_bin=21.6875, no_bins=64): + boundaries = torch.linspace(min_bin, max_bin, no_bins - 1, device=pts.device) + dists = torch.sqrt(torch.sum((pts.unsqueeze(-2) - pts.unsqueeze(-3)) ** 2, dim=-1)) + return torch.bucketize(dists, boundaries) + + +def dict_multimap(fn, dicts): + first = dicts[0] + new_dict = {} + for k, v in first.items(): + all_v = [d[k] for d in dicts] + if type(v) is dict: + new_dict[k] = dict_multimap(fn, all_v) + else: + new_dict[k] = fn(all_v) + + return new_dict + + +def one_hot(x, v_bins): + reshaped_bins = v_bins.view(((1,) * len(x.shape)) + (len(v_bins),)) + diffs = x[..., None] - reshaped_bins + am = torch.argmin(torch.abs(diffs), dim=-1) + return nn.functional.one_hot(am, num_classes=len(v_bins)).float() + + +def batched_gather(data, inds, dim=0, no_batch_dims=0): + ranges = [] + for i, s in enumerate(data.shape[:no_batch_dims]): + r = torch.arange(s) + r = r.view(*(*((1,) * i), -1, *((1,) * (len(inds.shape) - i - 1)))) + ranges.append(r) + + remaining_dims = [slice(None) for _ in range(len(data.shape) - no_batch_dims)] + remaining_dims[dim - no_batch_dims if dim >= 0 else dim] = inds + ranges.extend(remaining_dims) + # Matt note: Editing this to get around the behaviour of using a list as an array index changing + # in recent Numpy versions + return data[tuple(ranges)] + + +# With tree_map, a poor man's JAX tree_map +def dict_map(fn, dic, leaf_type): + new_dict = {} + for k, v in dic.items(): + if type(v) is dict: + new_dict[k] = dict_map(fn, v, leaf_type) + else: + new_dict[k] = tree_map(fn, v, leaf_type) + + return new_dict + + +def tree_map(fn, tree, leaf_type): + if isinstance(tree, dict): + return dict_map(fn, tree, leaf_type) + elif isinstance(tree, list): + return [tree_map(fn, x, leaf_type) for x in tree] + elif isinstance(tree, tuple): + return tuple([tree_map(fn, x, leaf_type) for x in tree]) + elif isinstance(tree, leaf_type): + return fn(tree) + else: + print(type(tree)) + raise ValueError("Not supported") + + +tensor_tree_map = partial(tree_map, leaf_type=torch.Tensor) diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index 52f3d39b5d8efe..cb2f93be0fc905 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -1985,6 +1985,13 @@ def __init__(self, *args, **kwargs): ESM_PRETRAINED_MODEL_ARCHIVE_LIST = None +class EsmFoldPreTrainedModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + class EsmForMaskedLM(metaclass=DummyObject): _backends = ["torch"] @@ -1992,6 +1999,13 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) +class EsmForProteinFolding(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + class EsmForSequenceClassification(metaclass=DummyObject): _backends = ["torch"] diff --git a/tests/models/esm/test_modeling_esm.py b/tests/models/esm/test_modeling_esm.py index 9733223e39c73d..d8ec7929f589cd 100644 --- a/tests/models/esm/test_modeling_esm.py +++ b/tests/models/esm/test_modeling_esm.py @@ -20,7 +20,6 @@ from transformers import EsmConfig, is_torch_available from transformers.testing_utils import TestCasePlus, require_torch, slow, torch_device -from ...generation.test_generation_utils import GenerationTesterMixin from ...test_configuration_common import ConfigTester from ...test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask @@ -49,7 +48,7 @@ def __init__( self.use_input_mask = True self.use_token_type_ids = False self.use_labels = True - self.vocab_size = 99 + self.vocab_size = 33 self.hidden_size = 32 self.num_hidden_layers = 5 self.num_attention_heads = 4 @@ -145,7 +144,7 @@ def prepare_config_and_inputs_for_common(self): @require_torch -class EsmModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): +class EsmModelTest(ModelTesterMixin, unittest.TestCase): test_mismatched_shapes = False @@ -253,28 +252,32 @@ def test_resize_tokens_embeddings(self): class EsmModelIntegrationTest(TestCasePlus): @slow def test_inference_masked_lm(self): - model = EsmForMaskedLM.from_pretrained("Rocketknight1/esm2_t6_8M_UR50D") - input_ids = torch.tensor([[0, 1, 2, 3, 4, 5]]) - output = model(input_ids)[0] + with torch.no_grad(): + model = EsmForMaskedLM.from_pretrained("Rocketknight1/esm2_t6_8M_UR50D") + model.eval() + input_ids = torch.tensor([[0, 1, 2, 3, 4, 5]]) + output = model(input_ids)[0] - vocab_size = 33 + vocab_size = 33 - expected_shape = torch.Size((1, 6, vocab_size)) - self.assertEqual(output.shape, expected_shape) + expected_shape = torch.Size((1, 6, vocab_size)) + self.assertEqual(output.shape, expected_shape) - expected_slice = torch.tensor( - [[[15.0973, -6.6406, -1.1351], [-0.2209, -9.9622, 4.2109], [-1.6055, -10.0023, 1.5914]]] - ) - self.assertTrue(torch.allclose(output[:, :3, :3], expected_slice, atol=1e-4)) + expected_slice = torch.tensor( + [[[15.0973, -6.6406, -1.1351], [-0.2209, -9.9622, 4.2109], [-1.6055, -10.0023, 1.5914]]] + ) + self.assertTrue(torch.allclose(output[:, :3, :3], expected_slice, atol=1e-4)) @slow def test_inference_no_head(self): - model = EsmModel.from_pretrained("Rocketknight1/esm2_t6_8M_UR50D") - - input_ids = torch.tensor([[0, 6, 4, 13, 5, 4, 16, 12, 11, 7, 2]]) - output = model(input_ids)[0] - # compare the actual values for a slice. - expected_slice = torch.tensor( - [[[0.1444, 0.5413, 0.3248], [0.3034, 0.0053, 0.3108], [0.3228, -0.2499, 0.3415]]] - ) - self.assertTrue(torch.allclose(output[:, :3, :3], expected_slice, atol=1e-4)) + with torch.no_grad(): + model = EsmModel.from_pretrained("Rocketknight1/esm2_t6_8M_UR50D") + model.eval() + + input_ids = torch.tensor([[0, 6, 4, 13, 5, 4, 16, 12, 11, 7, 2]]) + output = model(input_ids)[0] + # compare the actual values for a slice. + expected_slice = torch.tensor( + [[[0.1444, 0.5413, 0.3248], [0.3034, 0.0053, 0.3108], [0.3228, -0.2499, 0.3415]]] + ) + self.assertTrue(torch.allclose(output[:, :3, :3], expected_slice, atol=1e-4)) diff --git a/tests/models/esm/test_modeling_esmfold.py b/tests/models/esm/test_modeling_esmfold.py new file mode 100644 index 00000000000000..5813e782974299 --- /dev/null +++ b/tests/models/esm/test_modeling_esmfold.py @@ -0,0 +1,234 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Testing suite for the PyTorch ESM model. """ + + +import unittest + +from transformers import EsmConfig, is_torch_available +from transformers.testing_utils import TestCasePlus, require_torch, slow, torch_device + +from ...test_configuration_common import ConfigTester +from ...test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask + + +if is_torch_available(): + import torch + + from transformers.models.esm.modeling_esmfold import EsmForProteinFolding + + +class EsmFoldModelTester: + def __init__( + self, + parent, + ): + self.parent = parent + self.batch_size = 13 + self.seq_length = 7 + self.is_training = False + self.use_input_mask = True + self.use_token_type_ids = False + self.use_labels = False + self.vocab_size = 19 + self.hidden_size = 32 + self.num_hidden_layers = 5 + self.num_attention_heads = 4 + self.intermediate_size = 37 + self.hidden_act = "gelu" + self.hidden_dropout_prob = 0.1 + self.attention_probs_dropout_prob = 0.1 + self.max_position_embeddings = 512 + self.type_vocab_size = 16 + self.type_sequence_label_size = 2 + self.initializer_range = 0.02 + self.num_labels = 3 + self.num_choices = 4 + self.scope = None + + def prepare_config_and_inputs(self): + input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) + + input_mask = None + if self.use_input_mask: + input_mask = random_attention_mask([self.batch_size, self.seq_length]) + + sequence_labels = None + token_labels = None + choice_labels = None + if self.use_labels: + sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size) + token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels) + choice_labels = ids_tensor([self.batch_size], self.num_choices) + + config = self.get_config() + + return config, input_ids, input_mask, sequence_labels, token_labels, choice_labels + + def get_config(self): + config = EsmConfig( + vocab_size=33, + hidden_size=self.hidden_size, + pad_token_id=1, + num_hidden_layers=self.num_hidden_layers, + num_attention_heads=self.num_attention_heads, + intermediate_size=self.intermediate_size, + hidden_act=self.hidden_act, + hidden_dropout_prob=self.hidden_dropout_prob, + attention_probs_dropout_prob=self.attention_probs_dropout_prob, + max_position_embeddings=self.max_position_embeddings, + type_vocab_size=self.type_vocab_size, + initializer_range=self.initializer_range, + is_folding_model=True, + esmfold_config={"trunk": {"num_blocks": 2}, "fp16_esm": False}, + ) + return config + + def create_and_check_model(self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels): + model = EsmForProteinFolding(config=config).float() + model.to(torch_device) + model.eval() + result = model(input_ids, attention_mask=input_mask) + result = model(input_ids) + result = model(input_ids) + + self.parent.assertEqual(result.positions.shape, (8, self.batch_size, self.seq_length, 14, 3)) + self.parent.assertEqual(result.angles.shape, (8, self.batch_size, self.seq_length, 7, 2)) + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + ( + config, + input_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + ) = config_and_inputs + inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask} + return config, inputs_dict + + +@require_torch +class EsmFoldModelTest(ModelTesterMixin, unittest.TestCase): + + test_mismatched_shapes = False + + all_model_classes = (EsmForProteinFolding,) if is_torch_available() else () + all_generative_model_classes = () + test_sequence_classification_problem_types = False + + def setUp(self): + self.model_tester = EsmFoldModelTester(self) + self.config_tester = ConfigTester(self, config_class=EsmConfig, hidden_size=37) + + def test_config(self): + self.config_tester.run_common_tests() + + def test_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_model(*config_and_inputs) + + @unittest.skip("Does not support attention outputs") + def test_attention_outputs(self): + pass + + @unittest.skip + def test_correct_missing_keys(self): + pass + + @unittest.skip("Esm does not support embedding resizing") + def test_resize_embeddings_untied(self): + pass + + @unittest.skip("Esm does not support embedding resizing") + def test_resize_tokens_embeddings(self): + pass + + @unittest.skip("ESMFold does not support passing input embeds!") + def test_inputs_embeds(self): + pass + + @unittest.skip("ESMFold does not support head pruning.") + def test_head_pruning(self): + pass + + @unittest.skip("ESMFold does not support head pruning.") + def test_head_pruning_integration(self): + pass + + @unittest.skip("ESMFold does not support head pruning.") + def test_head_pruning_save_load_from_config_init(self): + pass + + @unittest.skip("ESMFold does not support head pruning.") + def test_head_pruning_save_load_from_pretrained(self): + pass + + @unittest.skip("ESMFold does not support head pruning.") + def test_headmasking(self): + pass + + @unittest.skip("ESMFold does not output hidden states in the normal way.") + def test_hidden_states_output(self): + pass + + @unittest.skip("ESMfold does not output hidden states in the normal way.") + def test_retain_grad_hidden_states_attentions(self): + pass + + @unittest.skip("ESMFold only has one output format.") + def test_model_outputs_equivalence(self): + pass + + @unittest.skip("This test doesn't work for ESMFold and doesn't test core functionality") + def test_save_load_fast_init_from_base(self): + pass + + @unittest.skip("ESMFold does not support input chunking.") + def test_feed_forward_chunking(self): + pass + + @unittest.skip("ESMFold doesn't respect you and it certainly doesn't respect your initialization arguments.") + def test_initialization(self): + pass + + @unittest.skip("ESMFold doesn't support torchscript compilation.") + def test_torchscript_output_attentions(self): + pass + + @unittest.skip("ESMFold doesn't support torchscript compilation.") + def test_torchscript_output_hidden_state(self): + pass + + @unittest.skip("ESMFold doesn't support torchscript compilation.") + def test_torchscript_simple(self): + pass + + @unittest.skip("ESMFold doesn't support data parallel.") + def test_multi_gpu_data_parallel_forward(self): + pass + + +@require_torch +class EsmModelIntegrationTest(TestCasePlus): + @slow + def test_inference_protein_folding(self): + model = EsmForProteinFolding.from_pretrained("Rocketknight1/esmfold_v1").float() + model.eval() + input_ids = torch.tensor([[0, 6, 4, 13, 5, 4, 16, 12, 11, 7, 2]]) + position_outputs = model(input_ids)["positions"] + expected_slice = torch.tensor([2.5828, 0.7993, -10.9334], dtype=torch.float32) + self.assertTrue(torch.allclose(position_outputs[0, 0, 0, 0], expected_slice, atol=1e-4)) diff --git a/tests/models/esm/test_modeling_tf_esm.py b/tests/models/esm/test_modeling_tf_esm.py index 9a9aa8b1ac8eac..513dbb1a7be347 100644 --- a/tests/models/esm/test_modeling_tf_esm.py +++ b/tests/models/esm/test_modeling_tf_esm.py @@ -254,7 +254,7 @@ def test_save_load_after_resize_token_embeddings(self): @require_tf class TFEsmModelIntegrationTest(unittest.TestCase): - @slow + @unittest.skip("Temporarily disabled as we update ESM model checkpoints") def test_inference_masked_lm(self): model = TFEsmForMaskedLM.from_pretrained("Rocketknight1/esm2_t6_8M_UR50D") @@ -268,7 +268,7 @@ def test_inference_masked_lm(self): ) self.assertTrue(numpy.allclose(output[:, :3, :3].numpy(), expected_slice.numpy(), atol=1e-4)) - @slow + @unittest.skip("Temporarily disabled as we update ESM model checkpoints") def test_inference_no_head(self): model = TFEsmModel.from_pretrained("Rocketknight1/esm2_t6_8M_UR50D") diff --git a/utils/check_inits.py b/utils/check_inits.py index 98d4caf010216b..9495746c9f4448 100644 --- a/utils/check_inits.py +++ b/utils/check_inits.py @@ -268,6 +268,7 @@ def get_transformers_submodules(): IGNORE_SUBMODULES = [ "convert_pytorch_checkpoint_to_tf2", "modeling_flax_pytorch_utils", + "models.esm.openfold_utils", ] diff --git a/utils/check_repo.py b/utils/check_repo.py index 6e89d97b9c8747..4b7ec38e80799a 100644 --- a/utils/check_repo.py +++ b/utils/check_repo.py @@ -140,6 +140,7 @@ # should **not** be the rule. IGNORE_NON_AUTO_CONFIGURED = PRIVATE_MODELS.copy() + [ # models to ignore for model xxx mapping + "EsmForProteinFolding", "TimeSeriesTransformerForPrediction", "PegasusXEncoder", "PegasusXDecoder",