diff --git a/src/transformers/models/switch_transformers/modeling_switch_transformers.py b/src/transformers/models/switch_transformers/modeling_switch_transformers.py
index 0772b57c76276..c4861f6bb5d9f 100644
--- a/src/transformers/models/switch_transformers/modeling_switch_transformers.py
+++ b/src/transformers/models/switch_transformers/modeling_switch_transformers.py
@@ -221,7 +221,7 @@ def __init__(self, config: SwitchTransformersConfig, has_relative_attention_bias
         self.o = nn.Linear(self.inner_dim, self.d_model, bias=False)
 
         if self.has_relative_attention_bias:
-            self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.n_heads)
+            self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, 32 ) #self.n_heads)
         self.pruned_heads = set()
         self.gradient_checkpointing = False
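
For context, the second argument of this `nn.Embedding` is the per-bucket bias width, and in the usual T5-style `compute_bias` path it has to match the number of attention heads, since the bias is permuted to `(1, n_heads, query_len, key_len)` and added to the attention scores. Below is a minimal sketch of that shape constraint; the shapes and variable names are illustrative assumptions, not the actual HuggingFace module:

```python
import torch
import torch.nn as nn

# Illustrative shapes only (assumed values, not read from a real config)
num_buckets, n_heads, q_len, k_len = 32, 12, 4, 4

# Original line: the embedding's second dim equals n_heads,
# so the bias broadcasts cleanly over the attention scores.
relative_attention_bias = nn.Embedding(num_buckets, n_heads)

buckets = torch.randint(0, num_buckets, (q_len, k_len))  # relative-position buckets
bias = relative_attention_bias(buckets)                  # (q_len, k_len, n_heads)
bias = bias.permute(2, 0, 1).unsqueeze(0)                # (1, n_heads, q_len, k_len)

scores = torch.zeros(1, n_heads, q_len, k_len)           # attention scores
scores = scores + bias                                    # shapes line up

# With the hunk's hardcoded width of 32, the bias would instead come out as
# (1, 32, q_len, k_len), and this addition fails whenever n_heads != 32.
```

So the hardcoded `32` only behaves like the original line when `config.num_heads` happens to be 32; for any other head count the bias no longer broadcasts against the scores.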