Fix bug to make feature extractor resize only shorter edge
Niels Rogge authored and Niels Rogge committed Jul 8, 2022
1 parent 3285f53 commit b7b8b47
Showing 4 changed files with 36 additions and 30 deletions.
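
To illustrate the intent of the fix (an illustrative sketch, not part of the commit; the printed shape is an assumption based on the docstring changes below): with an integer `size`, the feature extractor now resizes the shorter edge of each frame and relies on the center crop for the final square, instead of squashing every frame to `size` x `size` and distorting the aspect ratio.

import numpy as np
from transformers import VideoMAEFeatureExtractor

# 16 landscape frames of shape (height, width, channels) = (360, 640, 3)
video = [np.random.rand(360, 640, 3) for _ in range(16)]

feature_extractor = VideoMAEFeatureExtractor(size=224)
encoding = feature_extractor(video, return_tensors="np")

# With the fix, each frame is resized so its shorter edge is 224 (keeping the aspect
# ratio) and then center-cropped; expected shape: (1, 16, 3, 224, 224).
print(encoding["pixel_values"].shape)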
16 changes: 5 additions & 11 deletions src/transformers/models/videomae/convert_videomae_to_pytorch.py
@@ -185,8 +185,6 @@ def convert_videomae_checkpoint(checkpoint_path, pytorch_dump_folder_path, model
     outputs = model(**inputs)
     logits = outputs.logits

-    print("Shape of logits:", logits.shape)
-
     model_names = [
         # Kinetics-400 checkpoints (short = pretrained only for 800 epochs instead of 1600)
         "videomae-base-short",
@@ -206,19 +204,15 @@ def convert_videomae_checkpoint(checkpoint_path, pytorch_dump_folder_path, model

     if model_name == "videomae-base-short":
         expected_shape = torch.Size([1, 1408, 1536])
-        expected_slice = torch.tensor(
-            [[-0.4798, -0.3191, -0.2558], [-0.3396, -0.2823, -0.1581], [0.4327, 0.4635, 0.4745]]
-        )
+        expected_slice = torch.tensor([[0.7994, 0.9612, 0.8508], [0.7401, 0.8958, 0.8302], [0.5862, 0.7468, 0.7325]])
         # we verified the loss both for normalized and unnormalized targets for this one
-        expected_loss = (
-            torch.tensor([0.5379046201705933]) if config.norm_pix_loss else torch.tensor([0.593469500541687])
-        )
+        expected_loss = torch.tensor([0.5142]) if config.norm_pix_loss else torch.tensor([0.6469])
     elif model_name == "videomae-base-finetuned-kinetics":
         expected_shape = torch.Size([1, 400])
-        expected_slice = torch.tensor([0.7666, -0.2265, -0.5551])
+        expected_slice = torch.tensor([0.3669, -0.0688, -0.2421])
     elif model_name == "videomae-base-finetuned-ssv2":
-        expected_shape = torch.Size([1, 74])
-        expected_slice = torch.tensor([-0.1354, -0.4494, -0.4979])
+        expected_shape = torch.Size([1, 174])
+        expected_slice = torch.tensor([-0.0537, -0.1539, -0.3266])

     # verify logits
     assert logits.shape == expected_shape
19 changes: 10 additions & 9 deletions src/transformers/models/videomae/feature_extraction_videomae.py
@@ -42,11 +42,11 @@ class VideoMAEFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMix
     Args:
         do_resize (`bool`, *optional*, defaults to `True`):
-            Whether to resize the input to a certain `size`.
+            Whether to resize the shorter edge of the input to a certain `size`.
         size (`int` or `Tuple(int)`, *optional*, defaults to 224):
-            Resize the input to the given size. If a tuple is provided, it should be (width, height). If only an
-            integer is provided, then the input will be resized to (size, size). Only has an effect if `do_resize` is
-            set to `True`.
+            Resize the shorter edge of the input to the given size. If a tuple is provided, it should be (width,
+            height). If only an integer is provided, then the input will be resized to (size, size). Only has an effect
+            if `do_resize` is set to `True`.
         resample (`int`, *optional*, defaults to `PIL.Image.BILINEAR`):
             An optional resampling filter. This can be one of `PIL.Image.NEAREST`, `PIL.Image.BOX`,
             `PIL.Image.BILINEAR`, `PIL.Image.HAMMING`, `PIL.Image.BICUBIC` or `PIL.Image.LANCZOS`. Only has an effect
@@ -84,7 +84,7 @@ def __init__(
         self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD

     def resize_video(self, video, size, resample="bilinear"):
-        return [self.resize(frame, size, resample) for frame in video]
+        return [self.resize(frame, size, resample, default_to_square=False) for frame in video]

     def crop_video(self, video, size):
         return [self.center_crop(frame, size) for frame in video]
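
The only functional change here is `default_to_square=False`. As a minimal sketch of what that flag means (an assumption about the resize mixin's behavior; `resize_shorter_edge` is a hypothetical helper, not library code): an integer `size` now targets the shorter edge, and the longer edge is scaled to preserve the aspect ratio, whereas the old call resized every frame to a `size` x `size` square.

from PIL import Image

def resize_shorter_edge(frame: Image.Image, size: int) -> Image.Image:
    # Hypothetical helper mirroring resize(..., default_to_square=False):
    # scale the shorter edge to `size`, keep the aspect ratio for the longer edge.
    width, height = frame.size
    short, long = (width, height) if width <= height else (height, width)
    new_short, new_long = size, int(size * long / short)
    new_width, new_height = (new_short, new_long) if width <= height else (new_long, new_short)
    return frame.resize((new_width, new_height), resample=Image.BILINEAR)

# A 640x360 frame becomes 398x224 (shorter edge = 224) rather than a squashed 224x224.
print(resize_shorter_edge(Image.new("RGB", (640, 360)), 224).size)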
@@ -105,10 +105,11 @@ def __call__(
         </Tip>
         Args:
-            videos (`List[PIL.Image.Image]`, `List[np.ndarray]`, `List[List[PIL.Image.Image]]`, `List[List[np.ndarrray]]`):
-                The video or batch of videos to be prepared. Each video should be a list of frames, which can be either
-                PIL images or NumPy arrays. In case of a NumPy array, each frame should be of shape (H, W, C), where H
-                and W are frame height and width, and C is a number of channels.
+            videos (`List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`, `List[List[PIL.Image.Image]]`, `List[List[np.ndarrray]]`,:
+                `List[List[torch.Tensor]]`): The video or batch of videos to be prepared. Each video should be a list
+                of frames, which can be either PIL images or NumPy arrays. In case of NumPy arrays/PyTorch tensors,
+                each frame should be of shape (H, W, C), where H and W are frame height and width, and C is a number of
+                channels.
             return_tensors (`str` or [`~utils.TensorType`], *optional*, defaults to `'np'`):
                 If set, will return tensors of a particular framework. Acceptable values are:
10 changes: 7 additions & 3 deletions src/transformers/models/videomae/test.py
@@ -1,5 +1,5 @@
-# import torch
 import numpy as np
+import torch

 from transformers import VideoMAEFeatureExtractor

@@ -10,8 +10,12 @@

 video = [np.random.rand(512, 640, 3), np.random.rand(312, 200, 3)]

-video = np.random.rand(16, 360, 640, 3)
-video = [video[i] for i in range(video.shape[0])]
+video = [np.random.rand(3, 512, 640), np.random.rand(3, 312, 200)]
+
+video = [torch.randn(3, 512, 640), torch.rand(3, 312, 200)]
+
+# video = np.random.rand(16, 360, 640, 3)
+# video = [video[i] for i in range(video.shape[0])]

 encoding = feature_extractor(video, return_tensors="pt")

21 changes: 14 additions & 7 deletions tests/models/videomae/test_modeling_videomae.py
@@ -372,7 +372,7 @@ def test_inference_for_video_classification(self):
         expected_shape = torch.Size((1, 400))
         self.assertEqual(outputs.logits.shape, expected_shape)

-        expected_slice = torch.tensor([0.7666, -0.2265, -0.5551]).to(torch_device)
+        expected_slice = torch.tensor([0.3669, -0.0688, -0.2421]).to(torch_device)

         self.assertTrue(torch.allclose(outputs.logits[0, :3], expected_slice, atol=1e-4))

@@ -395,14 +395,21 @@ def test_inference_for_pretraining(self):

         # verify the logits
         expected_shape = torch.Size([1, 1408, 1536])
-        expected_slice = torch.tensor(
-            [[-0.4798, -0.3191, -0.2558], [-0.3396, -0.2823, -0.1581], [0.4327, 0.4635, 0.4745]]
-        )
+        expected_slice = torch.tensor([[0.7994, 0.9612, 0.8508], [0.7401, 0.8958, 0.8302], [0.5862, 0.7468, 0.7325]])
         self.assertEqual(outputs.logits.shape, expected_shape)
         self.assertTrue(torch.allclose(outputs.logits[0, :3, :3], expected_slice, atol=1e-4))

-        # verify the loss
-        expected_loss = (
-            torch.tensor([0.5379046201705933]) if model.config.norm_pix_loss else torch.tensor([0.593469500541687])
-        )
+        # verify the loss (`config.norm_pix_loss` = `True`)
+        expected_loss = torch.tensor([0.5142])
         self.assertTrue(torch.allclose(outputs.loss, expected_loss, atol=1e-4))
+
+        # verify the loss (`config.norm_pix_loss` = `False`)
+        model = VideoMAEForPreTraining.from_pretrained("nielsr/videomae-base-short", norm_pix_loss=False).to(
+            torch_device
+        )
+
+        with torch.no_grad():
+            outputs = model(**inputs)
+
+        expected_loss = torch.tensor(torch.tensor([0.6469]))
+        self.assertTrue(torch.allclose(outputs.loss, expected_loss, atol=1e-4))
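
The two expected losses above correspond to the two settings of `norm_pix_loss`. A rough sketch of what that flag does, assuming VideoMAE follows the original MAE recipe (`reconstruction_loss` is an illustrative helper, not the model's actual code): when enabled, each target patch is normalized by its own mean and variance before the mean squared error against the reconstruction is taken, which is why the two settings yield different expected losses.

import torch

def reconstruction_loss(pred: torch.Tensor, target: torch.Tensor, norm_pix_loss: bool = True) -> torch.Tensor:
    # pred, target: (batch, num_masked_patches, patch_dim) -- assumed shapes for illustration.
    if norm_pix_loss:
        mean = target.mean(dim=-1, keepdim=True)
        var = target.var(dim=-1, keepdim=True)
        target = (target - mean) / (var + 1e-6) ** 0.5
    return ((pred - target) ** 2).mean()

pred, target = torch.randn(1, 8, 1536), torch.randn(1, 8, 1536)
print(reconstruction_loss(pred, target), reconstruction_loss(pred, target, norm_pix_loss=False))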
