Update design of patch embeddings
Niels Rogge authored and committed on Jun 22, 2022
commit eb392fc (1 parent: e2e6336)
Showing 1 changed file with 29 additions and 28 deletions: src/transformers/models/videomae/modeling_videomae.py
@@ -127,14 +127,7 @@ class VideoMAEEmbeddings(nn.Module):
    def __init__(self, config):
        super().__init__()

-        self.patch_embeddings = VideoMAEPatchEmbeddings(
-            image_size=config.image_size,
-            patch_size=config.patch_size,
-            num_channels=config.num_channels,
-            embed_dim=config.hidden_size,
-            num_frames=config.num_frames,
-            tubelet_size=config.tubelet_size,
-        )
+        self.patch_embeddings = VideoMAEPatchEmbeddings(config)
        self.num_patches = self.patch_embeddings.num_patches
        # fixed sin-cos embedding
        self.position_embeddings = get_sinusoid_encoding_table(self.num_patches, config.hidden_size)
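For context, `get_sinusoid_encoding_table` supplies the fixed sin-cos position embeddings referenced in the comment above. A minimal sketch of how such a table is typically built (an illustration of the idea, not necessarily the exact helper in this file):

```python
import torch


def sinusoid_encoding_table(num_positions: int, hidden_size: int) -> torch.Tensor:
    # One row per position; even dimensions get a sine, odd dimensions a cosine,
    # with geometrically increasing wavelengths (hidden_size assumed even).
    position = torch.arange(num_positions, dtype=torch.float32).unsqueeze(1)
    div_term = torch.pow(10000.0, torch.arange(0, hidden_size, 2, dtype=torch.float32) / hidden_size)
    table = torch.zeros(num_positions, hidden_size)
    table[:, 0::2] = torch.sin(position / div_term)
    table[:, 1::2] = torch.cos(position / div_term)
    return table.unsqueeze(0)  # (1, num_positions, hidden_size), broadcastable over the batch
```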
@@ -167,7 +160,16 @@ class VideoMAEPatchEmbeddings(nn.Module):
    """

-    def __init__(self, image_size=224, patch_size=16, num_channels=3, embed_dim=768, num_frames=16, tubelet_size=2):
+    def __init__(self, config):
+        image_size, patch_size, num_channels, hidden_size, num_frames, tubelet_size = (
+            config.image_size,
+            config.patch_size,
+            config.num_channels,
+            config.hidden_size,
+            config.num_frames,
+            config.tubelet_size,
+        )
+
        super().__init__()
        image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
        patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
@@ -177,16 +179,21 @@ def __init__(self, image_size=224, patch_size=16, num_channels=3, embed_dim=768,
        num_patches = (
            (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) * (num_frames // self.tubelet_size)
        )
+        self.num_channels = num_channels
        self.num_patches = num_patches
        self.projection = nn.Conv3d(
            in_channels=num_channels,
-            out_channels=embed_dim,
+            out_channels=hidden_size,
            kernel_size=(self.tubelet_size, patch_size[0], patch_size[1]),
            stride=(self.tubelet_size, patch_size[0], patch_size[1]),
        )

    def forward(self, pixel_values):
        batch_size, num_frames, num_channels, height, width = pixel_values.shape
+        if num_channels != self.num_channels:
+            raise ValueError(
+                "Make sure that the channel dimension of the pixel values matches the one set in the configuration."
+            )
        if height != self.image_size[0] or width != self.image_size[1]:
            raise ValueError(
                f"Input image size ({height}*{width}) doesn't match model ({self.image_size[0]}*{self.image_size[1]})."
@@ -573,16 +580,14 @@ def forward(
        ```python
        >>> from transformers import VideoMAEFeatureExtractor, VideoMAEModel
-        >>> from PIL import Image
-        >>> import requests
+        >>> import torch

-        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
-        >>> image = Image.open(requests.get(url, stream=True).raw)
+        >>> video = torch.randn(1, 16, 3, 224, 224)

        >>> feature_extractor = VideoMAEFeatureExtractor.from_pretrained("nanjing/videomae-base")
        >>> model = VideoMAEModel.from_pretrained("nanjing/videomae-base")

-        >>> inputs = feature_extractor(images=image, return_tensors="pt")
+        >>> inputs = feature_extractor(video, return_tensors="pt")
        >>> outputs = model(**inputs)
        >>> last_hidden_states = outputs.last_hidden_state
        ```"""
@@ -738,18 +743,17 @@ def forward(
        Examples:
        ```python
        >>> from transformers import VideoMAEFeatureExtractor, VideoMAEForPreTraining
-        >>> from PIL import Image
-        >>> import requests
+        >>> import numpy as np

-        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
-        >>> image = Image.open(requests.get(url, stream=True).raw)
+        >>> video = np.random.randn(16, 3, 224, 224).tolist()

        >>> feature_extractor = VideoMAEFeatureExtractor.from_pretrained("nanjing/vit-mae-base")
        >>> model = VideoMAEForPreTraining.from_pretrained("nanjing/vit-mae-base")

-        >>> inputs = feature_extractor(images=image, return_tensors="pt")
+        >>> pixel_values = feature_extractor(video, return_tensors="pt").pixel_values
+        >>> bool_masked_pos = ...
-        >>> outputs = model(**inputs)
+        >>> outputs = model(pixel_values, bool_masked_pos=bool_masked_pos)
        >>> loss = outputs.loss
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
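The `bool_masked_pos = ...` placeholder is left elided in the commit. For orientation, one hypothetical way to build it; the model only needs a boolean tensor of shape `(batch_size, num_patches)`, and the 0.9 masking ratio below is an assumption in the spirit of VideoMAE, not taken from this diff:

```python
import torch

seq_length, mask_ratio = 1568, 0.9     # 1568 = num_patches under the default config
num_masked = int(mask_ratio * seq_length)

noise = torch.rand(1, seq_length)      # one random score per patch
bool_masked_pos = torch.zeros(1, seq_length, dtype=torch.bool)
bool_masked_pos.scatter_(1, noise.topk(num_masked, dim=1).indices, True)  # mask the top-k scores
```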
@@ -780,8 +784,7 @@ def forward(
        decoder_outputs = self.decoder(x_full, pos_emd_mask.shape[1])  # [B, N_mask, 3 * 16 * 16]
        logits = decoder_outputs.logits

-        # TODO compute loss
-        # TODO check correct format of videos! (B, T, C, H, W)
+        # TODO verify loss computation
        loss = None
        with torch.no_grad():
            # calculate the labels to be predicted
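The remaining TODO signals that the loss is not final in this commit. One plausible shape for it, as a sketch only (not this commit's actual implementation): a mean-squared error between the decoder's predictions and the pixel targets of the masked tubelets.

```python
import torch
from torch import nn

batch_size, num_masked = 2, 1411                            # e.g. 90% of 1568 patches masked
logits = torch.randn(batch_size, num_masked, 3 * 16 * 16)   # stand-in decoder predictions
labels = torch.randn(batch_size, num_masked, 3 * 16 * 16)   # stand-in pixel targets
loss = nn.MSELoss()(logits, labels)
```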
@@ -905,16 +908,14 @@ def forward(
        ```python
        >>> from transformers import VideoMAEFeatureExtractor, VideoMAEForVideoClassification
-        >>> from PIL import Image
-        >>> import requests
+        >>> import numpy as np

-        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
-        >>> image = Image.open(requests.get(url, stream=True).raw)
+        >>> video = np.random.randn(16, 3, 224, 224).tolist()

        >>> feature_extractor = VideoMAEFeatureExtractor.from_pretrained("nanjing/videomae-base")
        >>> model = VideoMAEForVideoClassification.from_pretrained("nanjing/videomae-base")

-        >>> inputs = feature_extractor(images=image, return_tensors="pt")
+        >>> inputs = feature_extractor(video, return_tensors="pt")
        >>> outputs = model(**inputs)
        >>> logits = outputs.logits
        ```"""
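Continuing from the classification example above, the predicted label can be read off the logits; this assumes the checkpoint's config carries the standard `id2label` mapping, which is not shown in this diff:

```python
predicted_class_idx = logits.argmax(-1).item()
print(model.config.id2label[predicted_class_idx])
```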