diff --git a/src/llamafactory/data/collator.py b/src/llamafactory/data/collator.py
index 93cd1515..72953d7a 100644
--- a/src/llamafactory/data/collator.py
+++ b/src/llamafactory/data/collator.py
@@ -285,6 +285,8 @@ class KTODataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
         batch["kl_input_ids"] = kl_batch["input_ids"]
         batch["kl_attention_mask"] = kl_batch["attention_mask"]
         batch["kl_labels"] = kl_batch["labels"]
+        if "cross_attention_mask" in kl_batch:  # for mllama inputs.
+            batch["kl_cross_attention_mask"] = kl_batch["cross_attention_mask"]
         if "token_type_ids" in kl_batch:
             batch["kl_token_type_ids"] = kl_batch["token_type_ids"]

diff --git a/src/llamafactory/train/kto/trainer.py b/src/llamafactory/train/kto/trainer.py
index bde819b1..65c72449 100644
--- a/src/llamafactory/train/kto/trainer.py
+++ b/src/llamafactory/train/kto/trainer.py
@@ -156,6 +156,15 @@ class CustomKTOTrainer(KTOTrainer):
         if "image_grid_thw" in batch:
             model_inputs["image_grid_thw"] = batch["image_grid_thw"]

+        if "aspect_ratio_ids" in batch:
+            model_inputs["aspect_ratio_ids"] = batch["aspect_ratio_ids"]
+
+        if "aspect_ratio_mask" in batch:
+            model_inputs["aspect_ratio_mask"] = batch["aspect_ratio_mask"]
+
+        if f"{prefix}cross_attention_mask" in batch:
+            model_inputs["cross_attention_mask"] = batch[f"{prefix}cross_attention_mask"]
+
         logits = model(**model_inputs, return_dict=True, use_cache=False).logits.to(torch.float32)
         logps, valid_length = get_batch_logps(logits=logits, labels=batch[f"{prefix}labels"])
         return logits, logps, logps / valid_length
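
The sketch below is not part of the patch; it is a minimal, self-contained illustration of the key-routing idea the two hunks implement: the collator stores a separate `kl_cross_attention_mask` for the KL copy of each example (because the cross-attention mask depends on the text tokens), while `aspect_ratio_ids` / `aspect_ratio_mask` are stored once and reused, and the trainer maps the prefixed keys back to the names the model expects. The helper name `route_vision_inputs` and the toy tensor shapes are hypothetical, chosen only to make the example runnable.

```python
# Illustrative sketch only, assuming the key naming shown in the diff above.
import torch


def route_vision_inputs(batch: dict, prefix: str = "") -> dict:
    """Hypothetical helper: build model kwargs from a (possibly kl_-prefixed) batch."""
    model_inputs = {
        "input_ids": batch[f"{prefix}input_ids"],
        "attention_mask": batch[f"{prefix}attention_mask"],
    }
    # Shared mllama image metadata: stored once, no prefix.
    for key in ("aspect_ratio_ids", "aspect_ratio_mask"):
        if key in batch:
            model_inputs[key] = batch[key]
    # Text-dependent mask: the KL sequence carries its own prefixed copy.
    if f"{prefix}cross_attention_mask" in batch:
        model_inputs["cross_attention_mask"] = batch[f"{prefix}cross_attention_mask"]
    return model_inputs


# Toy batch holding both the target tensors and their kl_ counterparts (shapes are arbitrary).
batch = {
    "input_ids": torch.ones(1, 4, dtype=torch.long),
    "attention_mask": torch.ones(1, 4, dtype=torch.long),
    "cross_attention_mask": torch.ones(1, 4, 1, 4, dtype=torch.long),
    "kl_input_ids": torch.ones(1, 4, dtype=torch.long),
    "kl_attention_mask": torch.ones(1, 4, dtype=torch.long),
    "kl_cross_attention_mask": torch.ones(1, 4, 1, 4, dtype=torch.long),
    "aspect_ratio_ids": torch.ones(1, 1, dtype=torch.long),
    "aspect_ratio_mask": torch.ones(1, 1, 4, dtype=torch.long),
}

# Both the target pass and the KL pass end up with a cross_attention_mask entry.
assert "cross_attention_mask" in route_vision_inputs(batch)
assert "cross_attention_mask" in route_vision_inputs(batch, prefix="kl_")
```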