
Commit bf37e5c

Fix T5 incorrect weight decay in Trainer and official summarization example (#18002)

* Add ALL_LAYERNORM_LAYERS for LayerNorm

* Fix a bug when appending the layer norm class
ADAning committed Jul 6, 2022
1 parent 22edb68 commit bf37e5c
Showing 5 changed files with 12 additions and 4 deletions.
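In short: T5LayerNorm (and LongT5LayerNorm) is an RMS-style norm that does not subclass nn.LayerNorm, so the Trainer's decay filter built from [nn.LayerNorm] never matched those modules and their weights incorrectly received weight decay. A minimal illustration of the root cause, using a toy class rather than the library's implementation:

```python
# Toy sketch of the root cause (not the library's T5LayerNorm): T5 uses an
# RMS-style norm without mean subtraction or bias, and it does not subclass
# nn.LayerNorm, so a filter built from [nn.LayerNorm] alone never caught it.
import torch
import torch.nn as nn


class ToyT5LayerNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        # Scale by the root mean square of the activations only.
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        return self.weight * hidden_states * torch.rsqrt(variance + self.variance_epsilon)


print(isinstance(ToyT5LayerNorm(8), nn.LayerNorm))  # False -> missed by the old filter
```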
@@ -526,7 +526,7 @@ def postprocess_text(preds, labels):

# Optimizer
# Split weights in two groups, one with weight decay and the other not.
no_decay = ["bias", "LayerNorm.weight"]
no_decay = ["bias", "LayerNorm.weight", "layer_norm.weight"]
optimizer_grouped_parameters = [
{
"params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
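The new "layer_norm.weight" entry matters because T5's norm parameters use that lowercase spelling in their names, so the old list only ever matched "bias" for these models. A quick hedged check, assuming a local transformers install ("t5-small" is just an illustrative checkpoint):

```python
# Print a few T5 parameter names to show they spell the norm weight as
# "layer_norm.weight" / "final_layer_norm.weight", not "LayerNorm.weight".
from transformers import T5ForConditionalGeneration

model = T5ForConditionalGeneration.from_pretrained("t5-small")
layer_norm_names = [n for n, _ in model.named_parameters() if "layer_norm" in n]
print(layer_norm_names[:3])  # e.g. ['encoder.block.0.layer.0.layer_norm.weight', ...]
```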
5 changes: 4 additions & 1 deletion src/transformers/models/longt5/modeling_longt5.py
@@ -32,7 +32,8 @@
Seq2SeqLMOutput,
Seq2SeqModelOutput,
)
-from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
+from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import ALL_LAYERNORM_LAYERS, find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import (
DUMMY_INPUTS,
DUMMY_MASK,
@@ -260,6 +261,8 @@ def forward(self, hidden_states):
logger.warning("discovered apex but it failed to load, falling back to LongT5LayerNorm")
pass

+ALL_LAYERNORM_LAYERS.append(LongT5LayerNorm)


# Copied from transformers.models.t5.modeling_t5.T5DenseActDense with T5->LongT5
class LongT5DenseActDense(nn.Module):
4 changes: 3 additions & 1 deletion src/transformers/models/t5/modeling_t5.py
@@ -34,7 +34,7 @@
Seq2SeqModelOutput,
)
from ...modeling_utils import PreTrainedModel
-from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
+from ...pytorch_utils import ALL_LAYERNORM_LAYERS, find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import (
DUMMY_INPUTS,
DUMMY_MASK,
@@ -275,6 +275,8 @@ def forward(self, hidden_states):
logger.warning("discovered apex but it failed to load, falling back to T5LayerNorm")
pass

+ALL_LAYERNORM_LAYERS.append(T5LayerNorm)


class T5DenseActDense(nn.Module):
def __init__(self, config: T5Config):
2 changes: 2 additions & 0 deletions src/transformers/pytorch_utils.py
@@ -21,6 +21,8 @@
from .utils import logging


+ALL_LAYERNORM_LAYERS = [nn.LayerNorm]

logger = logging.get_logger(__name__)

is_torch_less_than_1_8 = version.parse(version.parse(torch.__version__).base_version) < version.parse("1.8.0")
3 changes: 2 additions & 1 deletion src/transformers/trainer.py
@@ -71,6 +71,7 @@
from .modeling_utils import PreTrainedModel, load_sharded_checkpoint, unwrap_model
from .models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
from .optimization import Adafactor, get_scheduler
+from .pytorch_utils import ALL_LAYERNORM_LAYERS
from .tokenization_utils_base import PreTrainedTokenizerBase
from .trainer_callback import (
CallbackHandler,
@@ -967,7 +968,7 @@ def create_optimizer(self):
opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model

if self.optimizer is None:
-decay_parameters = get_parameter_names(opt_model, [nn.LayerNorm])
+decay_parameters = get_parameter_names(opt_model, ALL_LAYERNORM_LAYERS)
decay_parameters = [name for name in decay_parameters if "bias" not in name]
optimizer_grouped_parameters = [
{
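Net effect in the Trainer: get_parameter_names excludes every parameter that lives inside one of the listed module classes, and with ALL_LAYERNORM_LAYERS now also carrying T5LayerNorm/LongT5LayerNorm, those norm weights drop out of the decay group. A minimal sketch of that filtering idea (not the Trainer's exact implementation; split_decay_groups is a hypothetical helper):

```python
# Sketch: exclude biases and all parameters owned by a registered layer-norm
# class from weight decay; everything else gets the configured decay.
import torch.nn as nn

ALL_LAYERNORM_LAYERS = [nn.LayerNorm]  # model files append their own norm classes


def split_decay_groups(model: nn.Module, weight_decay: float):
    no_decay = set()
    for module_name, module in model.named_modules():
        if isinstance(module, tuple(ALL_LAYERNORM_LAYERS)):
            no_decay.update(f"{module_name}.{name}" for name, _ in module.named_parameters())
    no_decay.update(name for name, _ in model.named_parameters() if "bias" in name)

    return [
        {
            "params": [p for n, p in model.named_parameters() if n not in no_decay],
            "weight_decay": weight_decay,
        },
        {
            "params": [p for n, p in model.named_parameters() if n in no_decay],
            "weight_decay": 0.0,
        },
    ]
```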
