diff --git a/src/transformers/models/videomae/modeling_videomae.py b/src/transformers/models/videomae/modeling_videomae.py index 9c2e5aecfb7780..68ab7a9fcb6d2f 100644 --- a/src/transformers/models/videomae/modeling_videomae.py +++ b/src/transformers/models/videomae/modeling_videomae.py @@ -630,7 +630,7 @@ class VideoMAEDecoder(nn.Module): def __init__(self, config, num_patches): super().__init__() - decoder_num_labels = 3 * config.tubelet_size * config.patch_size**2 + decoder_num_labels = config.num_channels * config.tubelet_size * config.patch_size**2 decoder_config = deepcopy(config) decoder_config.hidden_size = config.decoder_hidden_size @@ -768,20 +768,26 @@ def forward( ) sequence_output = outputs[0] - sequence_output = self.encoder_to_decoder(sequence_output) # [B, N_vis, C_d] + sequence_output = self.encoder_to_decoder( + sequence_output + ) # [batch_size, num_visible_patches, decoder_hidden_size] batch_size, seq_len, num_channels = sequence_output.shape - # we don't unshuffle the correct visible token order, - # but shuffle the pos embedding accordingly. - # TODO check for bool_masked_pos to be available + # we don't unshuffle the correct visible token order, but shuffle the position embeddings accordingly. + if bool_masked_pos is None: + raise ValueError("One must provided a boolean mask ") expanded_position_embeddings = self.position_embeddings.expand(batch_size, -1, -1).type_as(pixel_values) expanded_position_embeddings = expanded_position_embeddings.to(pixel_values.device).clone().detach() - pos_emd_vis = expanded_position_embeddings[~bool_masked_pos].reshape(batch_size, -1, num_channels) - pos_emd_mask = expanded_position_embeddings[bool_masked_pos].reshape(batch_size, -1, num_channels) + pos_emb_visible = expanded_position_embeddings[~bool_masked_pos].reshape(batch_size, -1, num_channels) + pos_emb_mask = expanded_position_embeddings[bool_masked_pos].reshape(batch_size, -1, num_channels) - x_full = torch.cat([sequence_output + pos_emd_vis, self.mask_token + pos_emd_mask], dim=1) # [B, N, C_d] + x_full = torch.cat( + [sequence_output + pos_emb_visible, self.mask_token + pos_emb_mask], dim=1 + ) # [batch_size, num_patches, decoder_hidden_size] - decoder_outputs = self.decoder(x_full, pos_emd_mask.shape[1]) # [B, N_mask, 3 * 16 * 16] + decoder_outputs = self.decoder( + x_full, pos_emb_mask.shape[1] + ) # [batch_size, num_masked_patches, num_channels * patch_size * patch_size] logits = decoder_outputs.logits # TODO verify loss computation