Remove RuntimeErrors for NaN-checking in 20B (huggingface#17563)
zphang authored and Narsil committed Jun 7, 2022
1 parent f3f2860 commit 65e538b
Showing 1 changed file with 0 additions and 6 deletions.
src/transformers/models/gpt_neox/modeling_gpt_neox.py (0 additions, 6 deletions)
@@ -193,8 +193,6 @@ def _attn(self, query, key, value, attention_mask=None, head_mask=None):
         query = query.view(batch_size * num_attention_heads, query_length, attn_head_size)
         key = key.view(batch_size * num_attention_heads, key_length, attn_head_size)
         attn_scores = torch.einsum("bik,bjk->bij", query, key) / self.norm_factor
-        if torch.isnan(attn_scores).any():
-            raise RuntimeError()
         attn_scores = attn_scores.view(batch_size, num_attention_heads, query_length, key_length)

         attn_scores = torch.where(causal_mask, attn_scores, self.masked_bias.to(attn_scores.dtype))
@@ -204,17 +202,13 @@ def _attn(self, query, key, value, attention_mask=None, head_mask=None):
            attn_scores = attn_scores + attention_mask

         attn_weights = nn.functional.softmax(attn_scores, dim=-1)
-        if torch.isnan(attn_weights).any():
-            raise RuntimeError()
         attn_weights = attn_weights.to(value.dtype)

         # Mask heads if we want to
         if head_mask is not None:
             attn_weights = attn_weights * head_mask

         attn_output = torch.matmul(attn_weights, value)
-        if torch.isnan(attn_output).any():
-            raise RuntimeError()
         return attn_output, attn_weights
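
The removed lines were hard-coded NaN guards that raised a bare RuntimeError on every forward pass of the attention block. For anyone who still wants NaN diagnostics while debugging a GPT-NeoX model, a minimal sketch follows; it is not part of this commit, and the hook function name and the `model` variable are illustrative assumptions. It uses PyTorch forward hooks so the check can be attached and detached without editing the modeling code:

import torch

def nan_hook(module, inputs, output):
    # Forward hook: report NaNs in a module's outputs instead of raising.
    outputs = output if isinstance(output, tuple) else (output,)
    for i, out in enumerate(outputs):
        if isinstance(out, torch.Tensor) and torch.isnan(out).any():
            print(f"NaN detected in output {i} of {module.__class__.__name__}")

# Usage, assuming `model` is an already-loaded GPTNeoX model:
# for submodule in model.modules():
#     submodule.register_forward_hook(nan_hook)

For NaNs that first appear in the backward pass, PyTorch's built-in anomaly detection (torch.autograd.set_detect_anomaly(True)) is another debugging-time option; both approaches keep the production forward path free of per-call checks.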

