From 80fe00d6069cc27c047dd0a1b768de0451e4aeb7 Mon Sep 17 00:00:00 2001 From: Matt Date: Mon, 31 Oct 2022 17:23:27 +0000 Subject: [PATCH] Make attention_mask optional (default to all 1s) --- src/transformers/models/esm/modeling_esmfold.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/esm/modeling_esmfold.py b/src/transformers/models/esm/modeling_esmfold.py index 244d78a174c196..b47fe07ef9bf5b 100644 --- a/src/transformers/models/esm/modeling_esmfold.py +++ b/src/transformers/models/esm/modeling_esmfold.py @@ -2040,7 +2040,7 @@ def forward(self, seq_feats, pair_feats, true_aa, residx, mask, no_recycles): s_z_0 = pair_feats if no_recycles is None: - no_recycles = self.cfg.max_recycles + no_recycles = self.config.max_recycles else: if no_recycles < 0: raise ValueError("Number of recycles must not be negative.") @@ -2190,7 +2190,7 @@ def _af2_to_esm_from_vocab_list(vocab_list: List[str]) -> torch.Tensor: def forward( self, input_ids: torch.Tensor, - attention_mask: torch.Tensor, + attention_mask: torch.Tensor = None, position_ids: Optional[torch.Tensor] = None, masking_pattern: Optional[torch.Tensor] = None, num_recycles: Optional[int] = None, @@ -2201,6 +2201,8 @@ def forward( B = aa.shape[0] L = aa.shape[1] device = input_ids.device + if attention_mask is None: + attention_mask = torch.ones_like(aa, device=device) if position_ids is None: position_ids = torch.arange(L, device=device).expand_as(input_ids)