diff --git a/src/transformers/models/esm/modeling_esmfold.py b/src/transformers/models/esm/modeling_esmfold.py
index 244d78a174c196..b47fe07ef9bf5b 100644
--- a/src/transformers/models/esm/modeling_esmfold.py
+++ b/src/transformers/models/esm/modeling_esmfold.py
@@ -2040,7 +2040,7 @@ def forward(self, seq_feats, pair_feats, true_aa, residx, mask, no_recycles):
         s_z_0 = pair_feats

         if no_recycles is None:
-            no_recycles = self.cfg.max_recycles
+            no_recycles = self.config.max_recycles
         else:
             if no_recycles < 0:
                 raise ValueError("Number of recycles must not be negative.")
@@ -2190,7 +2190,7 @@ def _af2_to_esm_from_vocab_list(vocab_list: List[str]) -> torch.Tensor:
     def forward(
         self,
         input_ids: torch.Tensor,
-        attention_mask: torch.Tensor,
+        attention_mask: torch.Tensor = None,
         position_ids: Optional[torch.Tensor] = None,
         masking_pattern: Optional[torch.Tensor] = None,
         num_recycles: Optional[int] = None,
@@ -2201,6 +2201,8 @@ def forward(
         B = aa.shape[0]
         L = aa.shape[1]
         device = input_ids.device
+        if attention_mask is None:
+            attention_mask = torch.ones_like(aa, device=device)
         if position_ids is None:
             position_ids = torch.arange(L, device=device).expand_as(input_ids)
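
For context, a minimal usage sketch of what these hunks enable: `attention_mask` becomes optional and defaults to an all-ones mask built from the input shape, so callers with a single unpadded sequence no longer have to construct one. The checkpoint name `facebook/esmfold_v1` and the example sequence below are assumptions for illustration; with padded batches, callers should still pass an explicit mask.

```python
import torch
from transformers import AutoTokenizer, EsmForProteinFolding

# Assumed checkpoint for illustration; any ESMFold checkpoint should behave the same.
tokenizer = AutoTokenizer.from_pretrained("facebook/esmfold_v1")
model = EsmForProteinFolding.from_pretrained("facebook/esmfold_v1")

# A single unpadded sequence (hypothetical example input).
inputs = tokenizer(["MLKNVHVLVLGAGDVGSVV"], return_tensors="pt", add_special_tokens=False)

# Before this patch, omitting attention_mask raised a TypeError because the
# parameter had no default; now the model builds an all-ones mask internally,
# treating every position as a real (unpadded) residue.
with torch.no_grad():
    outputs = model(input_ids=inputs["input_ids"])
```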