diff --git a/tests/models/imagegpt/test_modeling_imagegpt.py b/tests/models/imagegpt/test_modeling_imagegpt.py index 528532d4cd813..24d41348e1a4d 100644 --- a/tests/models/imagegpt/test_modeling_imagegpt.py +++ b/tests/models/imagegpt/test_modeling_imagegpt.py @@ -538,7 +538,8 @@ def test_inference_causal_lm_head(self): inputs = feature_extractor(images=image, return_tensors="pt").to(torch_device) # forward pass - outputs = model(**inputs) + with torch.no_grad(): + outputs = model(**inputs) # verify the logits expected_shape = torch.Size((1, 1024, 512))