diff --git a/src/transformers/models/auto/feature_extraction_auto.py b/src/transformers/models/auto/feature_extraction_auto.py
index 8e67c231d01784..0fc33a043e62f0 100644
--- a/src/transformers/models/auto/feature_extraction_auto.py
+++ b/src/transformers/models/auto/feature_extraction_auto.py
@@ -64,8 +64,8 @@
         ("speech_to_text", "Speech2TextFeatureExtractor"),
         ("swin", "ViTFeatureExtractor"),
         ("van", "ConvNextFeatureExtractor"),
-        ("vilt", "ViltFeatureExtractor"),
         ("videomae", "ViTFeatureExtractor"),
+        ("vilt", "ViltFeatureExtractor"),
         ("vit", "ViTFeatureExtractor"),
         ("vit_mae", "ViTFeatureExtractor"),
         ("wav2vec2", "Wav2Vec2FeatureExtractor"),
diff --git a/tests/models/videomae/test_modeling_videomae.py b/tests/models/videomae/test_modeling_videomae.py
index 6df20ddd72a03b..bdeedcb8e0ed9e 100644
--- a/tests/models/videomae/test_modeling_videomae.py
+++ b/tests/models/videomae/test_modeling_videomae.py
@@ -19,6 +19,9 @@
 import inspect
 import unittest
 
+import numpy as np
+
+from huggingface_hub import hf_hub_download
 from transformers import VideoMAEConfig
 from transformers.models.auto import get_values
 from transformers.testing_utils import require_torch, require_vision, slow, torch_device
@@ -42,8 +45,6 @@
 
 
 if is_vision_available():
-    from PIL import Image
-
     from transformers import VideoMAEFeatureExtractor
 
 
@@ -338,10 +339,12 @@ def check_hidden_states_output(inputs_dict, config, model_class):
             check_hidden_states_output(inputs_dict, config, model_class)
 
 
-# We will verify our results on an image of cute cats
-def prepare_img():
-    image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
-    return image
+# We will verify our results on a video of eating spaghetti
+# Frame indices used: [164 168 172 176 181 185 189 193 198 202 206 210 215 219 223 227]
+def prepare_video():
+    file = hf_hub_download(repo_id="datasets/hf-internal-testing/spaghetti-video", filename="eating_spaghetti.npy")
+    video = np.load(file)
+    return list(video)
 
 
 @require_torch
@@ -349,8 +352,26 @@ def prepare_img():
 class VideoMAEModelIntegrationTest(unittest.TestCase):
     @cached_property
     def default_feature_extractor(self):
-        return VideoMAEFeatureExtractor.from_pretrained("nanjing/videomae-base") if is_vision_available() else None
+        # TODO update to appropriate organization
+        return VideoMAEFeatureExtractor() if is_vision_available() else None
 
     @slow
-    def test_inference_for_pretraining(self):
-        raise NotImplementedError("To do")
+    def test_inference_for_video_classification(self):
+        # TODO update to appropriate organization
+        model = VideoMAEForVideoClassification.from_pretrained("nielsr/videomae-base").to(torch_device)
+
+        feature_extractor = self.default_feature_extractor
+        video = prepare_video()
+        inputs = feature_extractor(video, return_tensors="pt").to(torch_device)
+
+        # forward pass
+        with torch.no_grad():
+            outputs = model(**inputs)
+
+        # verify the logits
+        expected_shape = torch.Size((1, 400))
+        self.assertEqual(outputs.logits.shape, expected_shape)
+
+        expected_slice = torch.tensor([0.7666, -0.2265, -0.5551]).to(torch_device)
+
+        self.assertTrue(torch.allclose(outputs.logits[0, :3], expected_slice, atol=1e-4))