mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-12-29 02:00:36 +08:00
[model] Update ernie_vl to adapt new version (#9665)
This commit is contained in:
@@ -84,9 +84,7 @@ def load_reference_model(
|
||||
model: AutoModelForCausalLMWithValueHead = AutoModelForCausalLMWithValueHead.from_pretrained(
|
||||
model_path, torch_dtype=torch.float16, device_map="auto"
|
||||
)
|
||||
if not is_trainable:
|
||||
model.v_head = model.v_head.to(torch.float16)
|
||||
|
||||
|
||||
return model
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16, device_map="auto")
|
||||
|
||||
Reference in New Issue
Block a user