[model] Update ernie_vl to adapt new version (#9665)

This commit is contained in:
Xunpeng Xiao
2025-12-26 19:57:49 +08:00
committed by GitHub
parent a882e2d5fc
commit 3c17f2722c
5 changed files with 24 additions and 16 deletions

View File

@@ -205,10 +205,6 @@ def load_model(
if not is_trainable:
model.requires_grad_(False)
for param in model.parameters():
if param.data.dtype == torch.float32 and model_args.compute_dtype != torch.float32:
param.data = param.data.to(model_args.compute_dtype)
model.eval()
else:
model.train()