diff --git a/tests/xglm/test_modeling_xglm.py b/tests/xglm/test_modeling_xglm.py index 456c4eaf10735..1f80165a84cf8 100644 --- a/tests/xglm/test_modeling_xglm.py +++ b/tests/xglm/test_modeling_xglm.py @@ -418,15 +418,14 @@ def test_lm_generate_xglm_with_gradient_checkpointing(self): def test_xglm_sample(self): tokenizer = XGLMTokenizer.from_pretrained("facebook/xglm-564M") model = XGLMForCausalLM.from_pretrained("facebook/xglm-564M") - model.to(torch_device) torch.manual_seed(0) tokenized = tokenizer("Today is a nice day and", return_tensors="pt") - input_ids = tokenized.input_ids.to(torch_device) + input_ids = tokenized.input_ids output_ids = model.generate(input_ids, do_sample=True, num_beams=1) output_str = tokenizer.decode(output_ids[0], skip_special_tokens=True) - EXPECTED_OUTPUT_STR = "Today is a nice day and I am happy to show you all about a recent project for my" + EXPECTED_OUTPUT_STR = "Today is a nice day and the sun is shining. A nice day with warm rainy" self.assertEqual(output_str, EXPECTED_OUTPUT_STR) @slow