From 0ccb78f71393792a544c3d3d71b557b3b994e065 Mon Sep 17 00:00:00 2001
From: Niels Rogge
Date: Mon, 31 Oct 2022 18:01:00 +0100
Subject: [PATCH] Add support for gradient checkpointing

---
 .../modeling_bert_generation.py      |  5 +++++
 .../modeling_encoder_decoder.py      |  7 +++++++
 .../test_modeling_encoder_decoder.py | 21 +++++++++++++++++++
 3 files changed, 33 insertions(+)

diff --git a/src/transformers/models/bert_generation/modeling_bert_generation.py b/src/transformers/models/bert_generation/modeling_bert_generation.py
index 3fc06450afb86..0ea49f0d07995 100755
--- a/src/transformers/models/bert_generation/modeling_bert_generation.py
+++ b/src/transformers/models/bert_generation/modeling_bert_generation.py
@@ -581,6 +581,7 @@ class BertGenerationPreTrainedModel(PreTrainedModel):
 
     config_class = BertGenerationConfig
     base_model_prefix = "bert"
+    supports_gradient_checkpointing = True
     _keys_to_ignore_on_load_missing = [r"position_ids"]
 
     def _init_weights(self, module):
@@ -599,6 +600,10 @@ def _init_weights(self, module):
             module.bias.data.zero_()
             module.weight.data.fill_(1.0)
 
+    def _set_gradient_checkpointing(self, module, value=False):
+        if isinstance(module, BertEncoder):
+            module.gradient_checkpointing = value
+
 
 BERT_GENERATION_START_DOCSTRING = r"""
 
diff --git a/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py b/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py
index 9c340559f34a8..ed41e6a14e8ff 100644
--- a/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py
+++ b/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py
@@ -175,6 +175,8 @@ class EncoderDecoderModel(PreTrainedModel):
     """
     config_class = EncoderDecoderConfig
     base_model_prefix = "encoder_decoder"
+    main_input_name = "input_ids"
+    supports_gradient_checkpointing = True
 
     def __init__(
         self,
@@ -255,6 +257,11 @@ def tie_weights(self):
                 self.encoder, self.decoder._modules[decoder_base_model_prefix], self.decoder.base_model_prefix
             )
 
+    def _set_gradient_checkpointing(self, module, value=False):
+        # call both encoder and decoder function on gradient checkpointing
+        self.encoder._set_gradient_checkpointing(module, value=value)
+        self.decoder._set_gradient_checkpointing(module, value=value)
+
     def get_encoder(self):
         return self.encoder
 
diff --git a/tests/models/encoder_decoder/test_modeling_encoder_decoder.py b/tests/models/encoder_decoder/test_modeling_encoder_decoder.py
index 32cae5066669b..1181b94789e4e 100644
--- a/tests/models/encoder_decoder/test_modeling_encoder_decoder.py
+++ b/tests/models/encoder_decoder/test_modeling_encoder_decoder.py
@@ -611,6 +611,27 @@ def test_encoder_decoder_model_shared_weights(self):
         input_ids_dict = self.prepare_config_and_inputs()
         self.create_and_check_encoder_decoder_shared_weights(**input_ids_dict)
 
+    def test_training_gradient_checkpointing(self):
+        inputs_dict = self.prepare_config_and_inputs()
+        encoder_model, decoder_model = self.get_encoder_decoder_model(
+            inputs_dict["config"], inputs_dict["decoder_config"]
+        )
+
+        model = EncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
+        model.train()
+        model.gradient_checkpointing_enable()
+        model.config.decoder_start_token_id = 0
+        model.config.pad_token_id = 0
+
+        model_inputs = {
+            "input_ids": inputs_dict["input_ids"],
+            "attention_mask": inputs_dict["attention_mask"],
+            "labels": inputs_dict["labels"],
+            "decoder_input_ids": inputs_dict["decoder_input_ids"],
+        }
+        loss = model(**model_inputs).loss
+        loss.backward()
+
     @slow
     def test_real_model_save_load_from_pretrained(self):
         model_2 = self.get_pretrained_model()
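
Usage sketch (not part of the patch): with this change, gradient checkpointing on an
EncoderDecoderModel can be toggled through the standard PreTrainedModel API, and the call
is forwarded to both the encoder and the decoder. The checkpoint names below are placeholders.

    from transformers import EncoderDecoderModel

    # assemble an encoder-decoder model from two pretrained checkpoints (placeholder names)
    model = EncoderDecoderModel.from_encoder_decoder_pretrained("bert-base-uncased", "bert-base-uncased")

    # turn on activation checkpointing for both sub-models, then train as usual
    model.gradient_checkpointing_enable()
    model.train()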