Former-commit-id: 38b6b0f52e
This commit is contained in:
hiyouga
2024-06-16 01:06:41 +08:00
parent 96b82ccd4d
commit c0c6b8075a
22 changed files with 27 additions and 25 deletions

View File

@@ -41,7 +41,7 @@ def compare_model(model_a: "torch.nn.Module", model_b: "torch.nn.Module"):
state_dict_b = model_b.state_dict()
assert set(state_dict_a.keys()) == set(state_dict_b.keys())
for name in state_dict_a.keys():
assert torch.allclose(state_dict_a[name], state_dict_b[name]) is True
assert torch.allclose(state_dict_a[name], state_dict_b[name])
@pytest.fixture