Skip to content

Commit

Permalink
Update max_diff in test_save_load_fast_init_to_base (huggingface#…
Browse files Browse the repository at this point in the history
…19849)

* Fix test_save_load_fast_init_to_base

* Fix test_save_load_fast_init_to_base

* update

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
  • Loading branch information
2 people authored and amyeroberts committed Nov 1, 2022
1 parent a0c8625 commit f27cfd9
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion tests/test_modeling_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -398,7 +398,9 @@ class CopyClass(base_class):
model_slow_init = base_class_copy.from_pretrained(tmpdirname, _fast_init=False)

for key in model_fast_init.state_dict().keys():
max_diff = (model_slow_init.state_dict()[key] - model_fast_init.state_dict()[key]).sum().item()
max_diff = torch.max(
torch.abs(model_slow_init.state_dict()[key] - model_fast_init.state_dict()[key])
).item()
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")

def test_initialization(self):
Expand Down

0 comments on commit f27cfd9

Please sign in to comment.