diff --git a/tests/model/test_base.py b/tests/model/test_base.py index e1991b20..6431a504 100644 --- a/tests/model/test_base.py +++ b/tests/model/test_base.py @@ -73,7 +73,8 @@ def test_valuehead(): tokenizer_module["tokenizer"], model_args, finetuning_args, is_trainable=False, add_valuehead=True ) - ref_model = AutoModelForCausalLMWithValueHead.from_pretrained( + ref_model: "AutoModelForCausalLMWithValueHead" = AutoModelForCausalLMWithValueHead.from_pretrained( TINY_LLAMA_VALUEHEAD, torch_dtype=torch.float16, device_map=get_current_device() ) + ref_model.v_head = ref_model.v_head.to(torch.float16) compare_model(model, ref_model)