From 4e29e06bbbdc2e2790a41b18bc1baf136e3859e0 Mon Sep 17 00:00:00 2001 From: Partho Date: Tue, 11 Oct 2022 00:24:54 +0530 Subject: [PATCH] wrap forward passes with torch.no_grad() (#19438) --- tests/models/roformer/test_modeling_roformer.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/models/roformer/test_modeling_roformer.py b/tests/models/roformer/test_modeling_roformer.py index b1d7f3d8a67c3f..dadb0d8e747b6b 100644 --- a/tests/models/roformer/test_modeling_roformer.py +++ b/tests/models/roformer/test_modeling_roformer.py @@ -457,7 +457,8 @@ class RoFormerModelIntegrationTest(unittest.TestCase): def test_inference_masked_lm(self): model = RoFormerForMaskedLM.from_pretrained("junnyu/roformer_chinese_base") input_ids = torch.tensor([[0, 1, 2, 3, 4, 5]]) - output = model(input_ids)[0] + with torch.no_grad(): + output = model(input_ids)[0] # TODO Replace vocab size vocab_size = 50000