mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-25 07:12:50 +08:00
increase tol
Former-commit-id: de43bee0b004c7e90811100474b3113590d0f130
This commit is contained in:
parent
f25b8626bf
commit
2d2c78d66c
@ -59,7 +59,7 @@ def compare_model(model_a: "torch.nn.Module", model_b: "torch.nn.Module"):
|
|||||||
state_dict_b = model_b.state_dict()
|
state_dict_b = model_b.state_dict()
|
||||||
assert set(state_dict_a.keys()) == set(state_dict_b.keys())
|
assert set(state_dict_a.keys()) == set(state_dict_b.keys())
|
||||||
for name in state_dict_a.keys():
|
for name in state_dict_a.keys():
|
||||||
assert torch.allclose(state_dict_a[name], state_dict_b[name])
|
assert torch.allclose(state_dict_a[name], state_dict_b[name], rtol=1e-3, atol=1e-3)
|
||||||
|
|
||||||
|
|
||||||
def test_pissa_init():
|
def test_pissa_init():
|
||||||
|
Loading…
x
Reference in New Issue
Block a user