Skip to content

Commit

Permalink
Add integration test
Browse files Browse the repository at this point in the history
  • Loading branch information
Niels Rogge authored and Niels Rogge committed Jun 20, 2022
1 parent dbcff40 commit 36e2185
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 10 deletions.
2 changes: 1 addition & 1 deletion src/transformers/models/auto/feature_extraction_auto.py
Expand Up @@ -64,8 +64,8 @@
("speech_to_text", "Speech2TextFeatureExtractor"),
("swin", "ViTFeatureExtractor"),
("van", "ConvNextFeatureExtractor"),
("vilt", "ViltFeatureExtractor"),
("videomae", "ViTFeatureExtractor"),
("vilt", "ViltFeatureExtractor"),
("vit", "ViTFeatureExtractor"),
("vit_mae", "ViTFeatureExtractor"),
("wav2vec2", "Wav2Vec2FeatureExtractor"),
Expand Down
39 changes: 30 additions & 9 deletions tests/models/videomae/test_modeling_videomae.py
Expand Up @@ -19,6 +19,9 @@
import inspect
import unittest

import numpy as np

from huggingface_hub import hf_hub_download
from transformers import VideoMAEConfig
from transformers.models.auto import get_values
from transformers.testing_utils import require_torch, require_vision, slow, torch_device
Expand All @@ -42,8 +45,6 @@


if is_vision_available():
from PIL import Image

from transformers import VideoMAEFeatureExtractor


Expand Down Expand Up @@ -338,19 +339,39 @@ def check_hidden_states_output(inputs_dict, config, model_class):
check_hidden_states_output(inputs_dict, config, model_class)


# We will verify our results on an image of cute cats
def prepare_img():
image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
return image
# We will verify our results on a video of eating spaghetti
# Frame indices used: [164 168 172 176 181 185 189 193 198 202 206 210 215 219 223 227]
def prepare_video():
file = hf_hub_download(repo_id="datasets/hf-internal-testing/spaghetti-video", filename="eating_spaghetti.npy")
video = np.load(file)
return list(video)


@require_torch
@require_vision
class VideoMAEModelIntegrationTest(unittest.TestCase):
@cached_property
def default_feature_extractor(self):
return VideoMAEFeatureExtractor.from_pretrained("nanjing/videomae-base") if is_vision_available() else None
# TODO update to appropriate organization
return VideoMAEFeatureExtractor() if is_vision_available() else None

@slow
def test_inference_for_pretraining(self):
raise NotImplementedError("To do")
def test_inference_for_video_classification(self):
# TODO update to appropriate organization
model = VideoMAEForVideoClassification.from_pretrained("nielsr/videomae-base").to(torch_device)

feature_extractor = self.default_feature_extractor
video = prepare_video()
inputs = feature_extractor(video, return_tensors="pt").to(torch_device)

# forward pass
with torch.no_grad():
outputs = model(**inputs)

# verify the logits
expected_shape = torch.Size((1, 400))
self.assertEqual(outputs.logits.shape, expected_shape)

expected_slice = torch.tensor([0.7666, -0.2265, -0.5551]).to(torch_device)

self.assertTrue(torch.allclose(outputs.logits[0, :3], expected_slice, atol=1e-4))

0 comments on commit 36e2185

Please sign in to comment.