[trainer] fix llama3.2 vision kto train (#6904)

This commit is contained in:
marko1616
2025-02-12 19:09:14 +08:00
committed by GitHub
parent 2f8b6847f5
commit b7fd1e9c00
2 changed files with 11 additions and 0 deletions

View File

@@ -156,6 +156,15 @@ class CustomKTOTrainer(KTOTrainer):
if "image_grid_thw" in batch:
model_inputs["image_grid_thw"] = batch["image_grid_thw"]
if "aspect_ratio_ids" in batch:
model_inputs["aspect_ratio_ids"] = batch["aspect_ratio_ids"]
if "aspect_ratio_mask" in batch:
model_inputs["aspect_ratio_mask"] = batch["aspect_ratio_mask"]
if f"{prefix}cross_attention_mask" in batch:
model_inputs["cross_attention_mask"] = batch[f"{prefix}cross_attention_mask"]
logits = model(**model_inputs, return_dict=True, use_cache=False).logits.to(torch.float32)
logps, valid_length = get_batch_logps(logits=logits, labels=batch[f"{prefix}labels"])
return logits, logps, logps / valid_length