mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-10-14 23:58:11 +08:00
fix test
Former-commit-id: a0a23f79d2d94d68e3bf1e90b95beff817bc409c
This commit is contained in:
parent
5b2284a51d
commit
3a8b2890eb
@ -37,9 +37,9 @@ def compare_model(model_a: "torch.nn.Module", model_b: "torch.nn.Module", diff_k
|
|||||||
assert set(state_dict_a.keys()) == set(state_dict_b.keys())
|
assert set(state_dict_a.keys()) == set(state_dict_b.keys())
|
||||||
for name in state_dict_a.keys():
|
for name in state_dict_a.keys():
|
||||||
if any(key in name for key in diff_keys):
|
if any(key in name for key in diff_keys):
|
||||||
assert torch.allclose(state_dict_a[name], state_dict_b[name], rtol=1e-3, atol=1e-4) is False
|
assert torch.allclose(state_dict_a[name], state_dict_b[name], rtol=1e-2, atol=1e-3) is False
|
||||||
else:
|
else:
|
||||||
assert torch.allclose(state_dict_a[name], state_dict_b[name], rtol=1e-3, atol=1e-4) is True
|
assert torch.allclose(state_dict_a[name], state_dict_b[name], rtol=1e-2, atol=1e-3) is True
|
||||||
|
|
||||||
|
|
||||||
def check_lora_model(model: "LoraModel") -> Tuple[Set[str], Set[str]]:
|
def check_lora_model(model: "LoraModel") -> Tuple[Set[str], Set[str]]:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user