mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-02 03:32:50 +08:00
[trainer] fix llama3.2 vision kto train (#6904)
Former-commit-id: b7fd1e9c00c77a4c2a0f2f347767d22bd47213f1
This commit is contained in:
parent
2e2f6bea07
commit
bae934dea3
@ -285,6 +285,8 @@ class KTODataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
|
||||
batch["kl_input_ids"] = kl_batch["input_ids"]
|
||||
batch["kl_attention_mask"] = kl_batch["attention_mask"]
|
||||
batch["kl_labels"] = kl_batch["labels"]
|
||||
if "cross_attention_mask" in kl_batch: # for mllama inputs.
|
||||
batch["kl_cross_attention_mask"] = kl_batch["cross_attention_mask"]
|
||||
if "token_type_ids" in kl_batch:
|
||||
batch["kl_token_type_ids"] = kl_batch["token_type_ids"]
|
||||
|
||||
|
@ -156,6 +156,15 @@ class CustomKTOTrainer(KTOTrainer):
|
||||
if "image_grid_thw" in batch:
|
||||
model_inputs["image_grid_thw"] = batch["image_grid_thw"]
|
||||
|
||||
if "aspect_ratio_ids" in batch:
|
||||
model_inputs["aspect_ratio_ids"] = batch["aspect_ratio_ids"]
|
||||
|
||||
if "aspect_ratio_mask" in batch:
|
||||
model_inputs["aspect_ratio_mask"] = batch["aspect_ratio_mask"]
|
||||
|
||||
if f"{prefix}cross_attention_mask" in batch:
|
||||
model_inputs["cross_attention_mask"] = batch[f"{prefix}cross_attention_mask"]
|
||||
|
||||
logits = model(**model_inputs, return_dict=True, use_cache=False).logits.to(torch.float32)
|
||||
logps, valid_length = get_batch_logps(logits=logits, labels=batch[f"{prefix}labels"])
|
||||
return logits, logps, logps / valid_length
|
||||
|
Loading…
x
Reference in New Issue
Block a user