[model] Update ernie_vl to adapt new version (#9665)

This commit is contained in:
Xunpeng Xiao
2025-12-26 19:57:49 +08:00
committed by GitHub
parent a882e2d5fc
commit 3c17f2722c
5 changed files with 24 additions and 16 deletions

View File

@@ -84,9 +84,7 @@ def load_reference_model(
model: AutoModelForCausalLMWithValueHead = AutoModelForCausalLMWithValueHead.from_pretrained(
model_path, torch_dtype=torch.float16, device_map="auto"
)
if not is_trainable:
model.v_head = model.v_head.to(torch.float16)
return model
model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16, device_map="auto")