From bae934dea318bb0352a09642c32f840eafdd4f16 Mon Sep 17 00:00:00 2001 From: marko1616 <45327989+marko1616@users.noreply.github.com> Date: Wed, 12 Feb 2025 19:09:14 +0800 Subject: [PATCH] [trainer] fix llama3.2 vision kto train (#6904) Former-commit-id: b7fd1e9c00c77a4c2a0f2f347767d22bd47213f1 --- src/llamafactory/data/collator.py | 2 ++ src/llamafactory/train/kto/trainer.py | 9 +++++++++ 2 files changed, 11 insertions(+) diff --git a/src/llamafactory/data/collator.py b/src/llamafactory/data/collator.py index 93cd1515..72953d7a 100644 --- a/src/llamafactory/data/collator.py +++ b/src/llamafactory/data/collator.py @@ -285,6 +285,8 @@ class KTODataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq): batch["kl_input_ids"] = kl_batch["input_ids"] batch["kl_attention_mask"] = kl_batch["attention_mask"] batch["kl_labels"] = kl_batch["labels"] + if "cross_attention_mask" in kl_batch: # for mllama inputs. + batch["kl_cross_attention_mask"] = kl_batch["cross_attention_mask"] if "token_type_ids" in kl_batch: batch["kl_token_type_ids"] = kl_batch["token_type_ids"] diff --git a/src/llamafactory/train/kto/trainer.py b/src/llamafactory/train/kto/trainer.py index bde819b1..65c72449 100644 --- a/src/llamafactory/train/kto/trainer.py +++ b/src/llamafactory/train/kto/trainer.py @@ -156,6 +156,15 @@ class CustomKTOTrainer(KTOTrainer): if "image_grid_thw" in batch: model_inputs["image_grid_thw"] = batch["image_grid_thw"] + if "aspect_ratio_ids" in batch: + model_inputs["aspect_ratio_ids"] = batch["aspect_ratio_ids"] + + if "aspect_ratio_mask" in batch: + model_inputs["aspect_ratio_mask"] = batch["aspect_ratio_mask"] + + if f"{prefix}cross_attention_mask" in batch: + model_inputs["cross_attention_mask"] = batch[f"{prefix}cross_attention_mask"] + logits = model(**model_inputs, return_dict=True, use_cache=False).logits.to(torch.float32) logps, valid_length = get_batch_logps(logits=logits, labels=batch[f"{prefix}labels"]) return logits, logps, logps / valid_length