Fix LongformerModel hidden states (huggingface#15537)
* add undo padding

* fix

* fix tuple issue

* make style and quality

* move unpad logic to LongformerEncoder + unpad attentions + update tests

* move unpad logic to TFLongformerEncoder

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
2 people authored and Steven committed on Feb 18, 2022
1 parent 80e9655 commit f71d32b
Showing 4 changed files with 23 additions and 23 deletions.
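Before the diffs, a note on the mechanics: Longformer pads its input so the sequence length is a multiple of config.attention_window, and this commit moves the step that undoes that padding from the model forward into the encoders, also trimming the returned hidden states and attentions. Below is a minimal sketch of the pad -> encode -> unpad round trip; seq_len, attention_window, and hidden_size are made-up values, not taken from the diff.

# Minimal sketch (not from the diff) of the pad -> encode -> unpad round trip.
import torch

seq_len, attention_window, hidden_size = 10, 4, 8

# Amount of padding needed to reach a multiple of attention_window
# (the same expression the removed test override used):
padding_len = (attention_window - seq_len % attention_window) % attention_window  # -> 2

# Pretend this is the encoder output over the padded sequence:
hidden_states = torch.zeros(1, seq_len + padding_len, hidden_size)

# Undo the padding so callers see a length equal to input_ids.size(1):
if padding_len > 0:
    hidden_states = hidden_states[:, :-padding_len]

assert hidden_states.shape[1] == seq_len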
17 changes: 12 additions & 5 deletions src/transformers/models/longformer/modeling_longformer.py
@@ -1246,6 +1246,7 @@ def forward(
hidden_states,
attention_mask=None,
head_mask=None,
padding_len=0,
output_attentions=False,
output_hidden_states=False,
return_dict=True,
@@ -1308,6 +1309,16 @@ def custom_forward(*inputs):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)

# undo padding
if padding_len > 0:
# unpad `hidden_states` because the calling function is expecting a length == input_ids.size(1)
hidden_states = hidden_states[:, :-padding_len]
if output_hidden_states:
all_hidden_states = tuple([state[:, :-padding_len] for state in all_hidden_states])

if output_attentions:
all_attentions = tuple([state[:, :, :-padding_len, :] for state in all_attentions])

if not return_dict:
return tuple(
v for v in [hidden_states, all_hidden_states, all_attentions, all_global_attentions] if v is not None
@@ -1697,18 +1708,14 @@ def forward(
embedding_output,
attention_mask=extended_attention_mask,
head_mask=head_mask,
padding_len=padding_len,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
sequence_output = encoder_outputs[0]
pooled_output = self.pooler(sequence_output) if self.pooler is not None else None

# undo padding
if padding_len > 0:
# unpad `sequence_output` because the calling function is expecting a length == input_ids.size(1)
sequence_output = sequence_output[:, :-padding_len]

if not return_dict:
return (sequence_output, pooled_output) + encoder_outputs[1:]

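The block added to LongformerEncoder.forward above trims not only hidden_states but every entry of all_hidden_states and all_attentions, so the tuples it returns match the caller's original sequence length. A standalone sketch of that slicing, assuming made-up shapes (batch 2, 4 heads, padded length 12, padding_len 2):

# Standalone sketch of the trimming added to LongformerEncoder.forward; shapes are illustrative.
import torch

padding_len = 2
all_hidden_states = tuple(torch.zeros(2, 12, 8) for _ in range(3))   # (batch, padded_seq, hidden)
all_attentions = tuple(torch.zeros(2, 4, 12, 9) for _ in range(3))   # (batch, heads, padded_seq, window span)

if padding_len > 0:
    all_hidden_states = tuple(state[:, :-padding_len] for state in all_hidden_states)
    all_attentions = tuple(state[:, :, :-padding_len, :] for state in all_attentions)

assert all(state.shape[1] == 10 for state in all_hidden_states)
assert all(state.shape[2] == 10 for state in all_attentions)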
17 changes: 11 additions & 6 deletions src/transformers/models/longformer/modeling_tf_longformer.py
@@ -1587,13 +1587,23 @@ def call(
all_attentions = all_attentions + (tf.transpose(layer_outputs[1], (0, 2, 1, 3)),)

# bzs x num_attn_heads x num_global_attn x seq_len => bzs x num_attn_heads x seq_len x num_global_attn
all_global_attentions = all_global_attentions + (tf.transpose(layer_outputs[2], (0, 1, 3, 2)))
all_global_attentions = all_global_attentions + (tf.transpose(layer_outputs[2], (0, 1, 3, 2)),)

# Add last layer
if output_hidden_states:
hidden_states_to_add = hidden_states[:, :-padding_len] if padding_len > 0 else hidden_states
all_hidden_states = all_hidden_states + (hidden_states_to_add,)

# undo padding
# unpad `hidden_states` because the calling function is expecting a length == input_ids.size(1)
hidden_states = hidden_states[:, :-padding_len] if padding_len > 0 else hidden_states
if output_attentions:
all_attentions = (
tuple([state[:, :, :-padding_len, :] for state in all_attentions])
if padding_len > 0
else all_attentions
)

if not return_dict:
return tuple(
v for v in [hidden_states, all_hidden_states, all_attentions, all_global_attentions] if v is not None
@@ -1763,11 +1773,6 @@ def call(
sequence_output = encoder_outputs[0]
pooled_output = self.pooler(sequence_output) if self.pooler is not None else None

# undo padding
if padding_len > 0:
# unpad `sequence_output` because the calling function is expecting a length == input_ids.size(1)
sequence_output = sequence_output[:, :-padding_len]

if not inputs["return_dict"]:
return (
sequence_output,
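The one-character change from (tf.transpose(layer_outputs[2], (0, 1, 3, 2))) to (tf.transpose(layer_outputs[2], (0, 1, 3, 2)),) in the hunk above is the "fix tuple issue" from the commit message: without the trailing comma the parentheses are only grouping, so the + does not append a one-element tuple to all_global_attentions. A plain-Python illustration with arbitrary values:

# Plain-Python illustration of the trailing-comma fix; the values are arbitrary.
accumulated = (1, 2)
item = 3

# Without the comma, (item) is just item with redundant parentheses,
# so tuple + int fails instead of appending:
try:
    accumulated + (item)
except TypeError as exc:
    print(exc)  # can only concatenate tuple (not "int") to tuple

# With the comma, (item,) is a one-element tuple and concatenation appends as intended:
accumulated = accumulated + (item,)
print(accumulated)  # (1, 2, 3)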
6 changes: 0 additions & 6 deletions tests/test_modeling_longformer.py
@@ -74,12 +74,6 @@ def __init__(
# is x + self.attention_window + 1, where x is the number of tokens with global attention)
self.key_length = self.attention_window + 2

# because of padding `encoder_seq_length`, is different from `seq_length`. Relevant for
# the `test_attention_outputs` and `test_hidden_states_output` tests
self.encoder_seq_length = (
self.seq_length + (self.attention_window - self.seq_length % self.attention_window) % self.attention_window
)

def prepare_config_and_inputs(self):
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)

6 changes: 0 additions & 6 deletions tests/test_modeling_tf_longformer.py
@@ -74,12 +74,6 @@ def __init__(
# because its local attention only attends to `self.attention_window` and one before and one after
self.key_length = self.attention_window + 2

# because of padding `encoder_seq_length`, is different from `seq_length`. Relevant for
# the `test_attention_outputs` and `test_hidden_states_output` tests
self.encoder_seq_length = (
self.seq_length + (self.attention_window - self.seq_length % self.attention_window) % self.attention_window
)

def prepare_config_and_inputs(self):
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)

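Both test-file overrides removed above existed only because the encoder used to return padded outputs, so test_attention_outputs and test_hidden_states_output had to compare against the padded length. A quick worked example with hypothetical numbers:

# Hypothetical numbers showing what the removed override computed and why it is no longer needed.
seq_length, attention_window = 7, 4
padding_len = (attention_window - seq_length % attention_window) % attention_window  # 1
old_encoder_seq_length = seq_length + padding_len                                    # 8
# With this commit the encoder trims the padding itself, so the hidden states and
# attentions the tests inspect already have length seq_length (7), not 8.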
