Skip to content

Commit

Permalink
[XGLM] run sampling test on CPU to be deterministic (#15892)
Browse files Browse the repository at this point in the history
* run sampling test on CPU to be deterministic

* input_ids on CPU
  • Loading branch information
patil-suraj committed Mar 2, 2022
1 parent baab5e7 commit 130b987
Showing 1 changed file with 2 additions and 3 deletions.
5 changes: 2 additions & 3 deletions tests/xglm/test_modeling_xglm.py
Expand Up @@ -418,15 +418,14 @@ def test_lm_generate_xglm_with_gradient_checkpointing(self):
def test_xglm_sample(self):
tokenizer = XGLMTokenizer.from_pretrained("facebook/xglm-564M")
model = XGLMForCausalLM.from_pretrained("facebook/xglm-564M")
model.to(torch_device)

torch.manual_seed(0)
tokenized = tokenizer("Today is a nice day and", return_tensors="pt")
input_ids = tokenized.input_ids.to(torch_device)
input_ids = tokenized.input_ids
output_ids = model.generate(input_ids, do_sample=True, num_beams=1)
output_str = tokenizer.decode(output_ids[0], skip_special_tokens=True)

EXPECTED_OUTPUT_STR = "Today is a nice day and I am happy to show you all about a recent project for my"
EXPECTED_OUTPUT_STR = "Today is a nice day and the sun is shining. A nice day with warm rainy"
self.assertEqual(output_str, EXPECTED_OUTPUT_STR)

@slow
Expand Down

0 comments on commit 130b987

Please sign in to comment.