Commit: Improve comments
Niels Rogge authored and Niels Rogge committed Jun 27, 2022
1 parent eb392fc commit 86078d9
Showing 1 changed file with 15 additions and 9 deletions.
24 changes: 15 additions & 9 deletions src/transformers/models/videomae/modeling_videomae.py
@@ -630,7 +630,7 @@ class VideoMAEDecoder(nn.Module):
     def __init__(self, config, num_patches):
         super().__init__()
 
-        decoder_num_labels = 3 * config.tubelet_size * config.patch_size**2
+        decoder_num_labels = config.num_channels * config.tubelet_size * config.patch_size**2
 
         decoder_config = deepcopy(config)
         decoder_config.hidden_size = config.decoder_hidden_size
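For context, decoder_num_labels is the number of pixel values the decoder predicts per patch, so deriving it from config.num_channels instead of a hard-coded 3 keeps it consistent with the configuration. A quick worked example, assuming the VideoMAE defaults of num_channels=3, tubelet_size=2 and patch_size=16 (assumed values, not taken from this diff):

num_channels, tubelet_size, patch_size = 3, 2, 16  # assumed default config values
decoder_num_labels = num_channels * tubelet_size * patch_size**2
print(decoder_num_labels)  # 1536 pixel values predicted per tubelet patch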
@@ -768,20 +768,26 @@ def forward(
         )
 
         sequence_output = outputs[0]
-        sequence_output = self.encoder_to_decoder(sequence_output)  # [B, N_vis, C_d]
+        sequence_output = self.encoder_to_decoder(
+            sequence_output
+        )  # [batch_size, num_visible_patches, decoder_hidden_size]
         batch_size, seq_len, num_channels = sequence_output.shape
 
-        # we don't unshuffle the correct visible token order,
-        # but shuffle the pos embedding accordingly.
-        # TODO check for bool_masked_pos to be available
+        # we don't unshuffle the correct visible token order, but shuffle the position embeddings accordingly.
+        if bool_masked_pos is None:
+            raise ValueError("One must provide a boolean mask")
         expanded_position_embeddings = self.position_embeddings.expand(batch_size, -1, -1).type_as(pixel_values)
         expanded_position_embeddings = expanded_position_embeddings.to(pixel_values.device).clone().detach()
-        pos_emd_vis = expanded_position_embeddings[~bool_masked_pos].reshape(batch_size, -1, num_channels)
-        pos_emd_mask = expanded_position_embeddings[bool_masked_pos].reshape(batch_size, -1, num_channels)
+        pos_emb_visible = expanded_position_embeddings[~bool_masked_pos].reshape(batch_size, -1, num_channels)
+        pos_emb_mask = expanded_position_embeddings[bool_masked_pos].reshape(batch_size, -1, num_channels)
 
-        x_full = torch.cat([sequence_output + pos_emd_vis, self.mask_token + pos_emd_mask], dim=1)  # [B, N, C_d]
+        x_full = torch.cat(
+            [sequence_output + pos_emb_visible, self.mask_token + pos_emb_mask], dim=1
+        )  # [batch_size, num_patches, decoder_hidden_size]
 
-        decoder_outputs = self.decoder(x_full, pos_emd_mask.shape[1])  # [B, N_mask, 3 * 16 * 16]
+        decoder_outputs = self.decoder(
+            x_full, pos_emb_mask.shape[1]
+        )  # [batch_size, num_masked_patches, num_channels * patch_size * patch_size]
         logits = decoder_outputs.logits
 
         # TODO verify loss computation
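The shuffling trick commented on in the second hunk boils down to boolean-mask indexing: the expanded position embeddings are split into a visible part (added to the encoder output) and a masked part (added to the learned mask token). A minimal standalone sketch of that indexing pattern, using illustrative shapes that are not the model's real dimensions:

import torch

batch_size, num_patches, hidden_size = 2, 8, 4  # illustrative shapes only
pos_emb = torch.randn(1, num_patches, hidden_size).expand(batch_size, -1, -1)

# True where a patch is masked out, False where it stays visible
bool_masked_pos = torch.zeros(batch_size, num_patches, dtype=torch.bool)
bool_masked_pos[:, ::2] = True  # mask every other patch

# boolean indexing flattens the selected embeddings, so reshape back to
# [batch_size, num_selected_per_example, hidden_size]
pos_emb_visible = pos_emb[~bool_masked_pos].reshape(batch_size, -1, hidden_size)
pos_emb_mask = pos_emb[bool_masked_pos].reshape(batch_size, -1, hidden_size)

print(pos_emb_visible.shape, pos_emb_mask.shape)
# torch.Size([2, 4, 4]) torch.Size([2, 4, 4])

Note the reshape only works when every example in the batch masks the same number of patches.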
