Make attention_mask optional (default to all 1s)
Rocketknight1 committed Oct 31, 2022
1 parent 64af806 commit 80fe00d
Showing 1 changed file with 4 additions and 2 deletions.
src/transformers/models/esm/modeling_esmfold.py (4 additions, 2 deletions)
@@ -2040,7 +2040,7 @@ def forward(self, seq_feats, pair_feats, true_aa, residx, mask, no_recycles):
         s_z_0 = pair_feats
 
         if no_recycles is None:
-            no_recycles = self.cfg.max_recycles
+            no_recycles = self.config.max_recycles
         else:
             if no_recycles < 0:
                 raise ValueError("Number of recycles must not be negative.")
@@ -2190,7 +2190,7 @@ def _af2_to_esm_from_vocab_list(vocab_list: List[str]) -> torch.Tensor:
     def forward(
         self,
         input_ids: torch.Tensor,
-        attention_mask: torch.Tensor,
+        attention_mask: torch.Tensor = None,
         position_ids: Optional[torch.Tensor] = None,
         masking_pattern: Optional[torch.Tensor] = None,
         num_recycles: Optional[int] = None,
@@ -2201,6 +2201,8 @@ def forward(
         B = aa.shape[0]
         L = aa.shape[1]
         device = input_ids.device
+        if attention_mask is None:
+            attention_mask = torch.ones_like(aa, device=device)
         if position_ids is None:
             position_ids = torch.arange(L, device=device).expand_as(input_ids)
 
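With this change, attention_mask no longer has to be passed explicitly; when it is omitted, forward() fills it with ones over every token position. A minimal usage sketch of the new calling convention, assuming the EsmForProteinFolding class and the facebook/esmfold_v1 checkpoint (illustrative here, not part of this commit):

    import torch
    from transformers import AutoTokenizer, EsmForProteinFolding

    tokenizer = AutoTokenizer.from_pretrained("facebook/esmfold_v1")
    model = EsmForProteinFolding.from_pretrained("facebook/esmfold_v1")

    # Tokenize a single protein sequence; no padding, so no mask is needed.
    inputs = tokenizer("MKTVRQERLKSIVRILERSKEPVSGAQ", return_tensors="pt", add_special_tokens=False)

    with torch.no_grad():
        # attention_mask is omitted; forward() now defaults it to all ones.
        outputs = model(input_ids=inputs["input_ids"])

Passing an explicit attention_mask still works exactly as before, and it remains necessary for padded batches, where padding positions must be masked out.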
