wrap forward passes with torch.no_grad() (huggingface#19439)
daspartho authored and amyeroberts committed Oct 18, 2022
1 parent 75e9e5a commit a70c81b
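
The change wraps inference-only forward passes in torch.no_grad(), which tells autograd not to record operations, so no computation graph is built and activations are not kept alive for a backward pass that never happens. A minimal sketch of the pattern (the nn.Linear stand-in and tensor shapes are illustrative, not taken from the test file):

    import torch
    from torch import nn

    model = nn.Linear(10, 2)  # stand-in for any nn.Module under test
    model.eval()              # inference mode: disables dropout, etc.
    inputs = torch.randn(1, 10)

    # Inside no_grad(), autograd skips graph construction, cutting memory
    # and compute for forward passes that will never be backpropagated.
    with torch.no_grad():
        output = model(inputs)

    assert not output.requires_grad  # no graph was attached to the output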
Showing 1 changed file with 36 additions and 32 deletions.
tests/models/visual_bert/test_modeling_visual_bert.py
@@ -568,14 +568,15 @@ def test_inference_vqa_coco_pre(self):
         attention_mask = torch.tensor([1] * 6).reshape(1, -1)
         visual_attention_mask = torch.tensor([1] * 10).reshape(1, -1)

-        output = model(
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            token_type_ids=token_type_ids,
-            visual_embeds=visual_embeds,
-            visual_attention_mask=visual_attention_mask,
-            visual_token_type_ids=visual_token_type_ids,
-        )
+        with torch.no_grad():
+            output = model(
+                input_ids=input_ids,
+                attention_mask=attention_mask,
+                token_type_ids=token_type_ids,
+                visual_embeds=visual_embeds,
+                visual_attention_mask=visual_attention_mask,
+                visual_token_type_ids=visual_token_type_ids,
+            )

         vocab_size = 30522

@@ -606,14 +607,15 @@ def test_inference_vqa(self):
         attention_mask = torch.tensor([1] * 6).reshape(1, -1)
         visual_attention_mask = torch.tensor([1] * 10).reshape(1, -1)

-        output = model(
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            token_type_ids=token_type_ids,
-            visual_embeds=visual_embeds,
-            visual_attention_mask=visual_attention_mask,
-            visual_token_type_ids=visual_token_type_ids,
-        )
+        with torch.no_grad():
+            output = model(
+                input_ids=input_ids,
+                attention_mask=attention_mask,
+                token_type_ids=token_type_ids,
+                visual_embeds=visual_embeds,
+                visual_attention_mask=visual_attention_mask,
+                visual_token_type_ids=visual_token_type_ids,
+            )

         # vocab_size = 30522

@@ -637,14 +639,15 @@ def test_inference_nlvr(self):
         attention_mask = torch.tensor([1] * 6).reshape(1, -1)
         visual_attention_mask = torch.tensor([1] * 10).reshape(1, -1)

-        output = model(
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            token_type_ids=token_type_ids,
-            visual_embeds=visual_embeds,
-            visual_attention_mask=visual_attention_mask,
-            visual_token_type_ids=visual_token_type_ids,
-        )
+        with torch.no_grad():
+            output = model(
+                input_ids=input_ids,
+                attention_mask=attention_mask,
+                token_type_ids=token_type_ids,
+                visual_embeds=visual_embeds,
+                visual_attention_mask=visual_attention_mask,
+                visual_token_type_ids=visual_token_type_ids,
+            )

         # vocab_size = 30522

@@ -667,14 +670,15 @@ def test_inference_vcr(self):
         visual_token_type_ids = torch.ones(size=(1, 4, 10), dtype=torch.long)
         visual_attention_mask = torch.ones_like(visual_token_type_ids)

-        output = model(
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            token_type_ids=token_type_ids,
-            visual_embeds=visual_embeds,
-            visual_attention_mask=visual_attention_mask,
-            visual_token_type_ids=visual_token_type_ids,
-        )
+        with torch.no_grad():
+            output = model(
+                input_ids=input_ids,
+                attention_mask=attention_mask,
+                token_type_ids=token_type_ids,
+                visual_embeds=visual_embeds,
+                visual_attention_mask=visual_attention_mask,
+                visual_token_type_ids=visual_token_type_ids,
+            )

         # vocab_size = 30522

