mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-10-14 15:52:49 +08:00
fix test case
Former-commit-id: 6663057cfbdc96385d901a5dfba22cfcd7a61b23
This commit is contained in:
parent
f51b435bcf
commit
cd899734f3
@ -73,7 +73,8 @@ def test_valuehead():
|
||||
tokenizer_module["tokenizer"], model_args, finetuning_args, is_trainable=False, add_valuehead=True
|
||||
)
|
||||
|
||||
ref_model = AutoModelForCausalLMWithValueHead.from_pretrained(
|
||||
ref_model: "AutoModelForCausalLMWithValueHead" = AutoModelForCausalLMWithValueHead.from_pretrained(
|
||||
TINY_LLAMA_VALUEHEAD, torch_dtype=torch.float16, device_map=get_current_device()
|
||||
)
|
||||
ref_model.v_head = ref_model.v_head.to(torch.float16)
|
||||
compare_model(model, ref_model)
|
||||
|
Loading…
x
Reference in New Issue
Block a user