diff --git a/src/llamafactory/train/dpo/trainer.py b/src/llamafactory/train/dpo/trainer.py index 770b32e5..1a1d5973 100644 --- a/src/llamafactory/train/dpo/trainer.py +++ b/src/llamafactory/train/dpo/trainer.py @@ -204,7 +204,11 @@ class CustomDPOTrainer(DPOTrainer): chosen_logps, rejected_logps = all_logps.split(batch_size, dim=0) chosen_logits, rejected_logits = all_logits.split(batch_size, dim=0) chosen_length, _ = valid_length.split(batch_size, dim=0) - return chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_logps / chosen_length + + if self.loss_type in ["ipo", "orpo", "simpo"]: + return chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_logps + else: + return chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_logps / chosen_length @override def compute_reference_log_probs(