Skip to content

Commit

Permalink
Bart: check if decoder_inputs_embeds is set (#13800)
Browse files Browse the repository at this point in the history
In BartForConditionalGeneration.forward, if labels are provided,
   decoder_input_ids are set to the labels shifted to the right.
   This is problematic: if decoder_inputs_embeds is also set,
   the call to self.model, which eventually gets to BartDecoder.forward,
   will raise an error.
   The fix is quite simple, similar to what is there already in
   BartModel.forward. Mainly, we should not
   compute decoder_input_ids if decoder_inputs_embeds is provided.

Co-authored-by: Silviu Vlad Oprea <silviuvo@amazon.co.uk>
  • Loading branch information
Silviu Oprea and Silviu Vlad Oprea committed Oct 1, 2021
1 parent 4213728 commit 707f7eb
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 2 deletions.
2 changes: 1 addition & 1 deletion src/transformers/models/bart/modeling_bart.py
Expand Up @@ -1291,7 +1291,7 @@ def forward(
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

if labels is not None:
if decoder_input_ids is None:
if decoder_input_ids is None and decoder_inputs_embeds is None:
decoder_input_ids = shift_tokens_right(
labels, self.config.pad_token_id, self.config.decoder_start_token_id
)
Expand Down
Expand Up @@ -2501,7 +2501,7 @@ def forward(
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

if labels is not None:
if decoder_input_ids is None:
if decoder_input_ids is None and decoder_inputs_embeds is None:
decoder_input_ids = shift_tokens_right(
labels, self.config.pad_token_id, self.config.decoder_start_token_id
)
Expand Down

0 comments on commit 707f7eb

Please sign in to comment.