From a8fae3869d4f2b0cd44b6ea2985189b414367def Mon Sep 17 00:00:00 2001
From: yinpu <741456392@qq.com>
Date: Tue, 21 Jan 2025 13:38:02 +0800
Subject: [PATCH] fix: avoid redundant normalization in DPO's SFT loss
 calculation (#6722)

Former-commit-id: 971a8ccbdacf130763d40c7ef82a711b2fc1292f
---
 src/llamafactory/train/dpo/trainer.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/src/llamafactory/train/dpo/trainer.py b/src/llamafactory/train/dpo/trainer.py
index 770b32e5..1a1d5973 100644
--- a/src/llamafactory/train/dpo/trainer.py
+++ b/src/llamafactory/train/dpo/trainer.py
@@ -204,7 +204,11 @@ class CustomDPOTrainer(DPOTrainer):
         chosen_logps, rejected_logps = all_logps.split(batch_size, dim=0)
         chosen_logits, rejected_logits = all_logits.split(batch_size, dim=0)
         chosen_length, _ = valid_length.split(batch_size, dim=0)
-        return chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_logps / chosen_length
+
+        if self.loss_type in ["ipo", "orpo", "simpo"]:
+            return chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_logps
+        else:
+            return chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_logps / chosen_length
 
     @override
     def compute_reference_log_probs(
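
Why the branch matters: for loss_type in {"ipo", "orpo", "simpo"}, concatenated_forward already divides all_logps by valid_length before the split, so the fifth return value (the per-sequence log-probs fed into the SFT loss term) must not be divided by chosen_length a second time. Below is a minimal sketch, not code from the repository, showing the scale error the old code introduced; the tensor values are illustrative assumptions.

import torch

# Summed log-probabilities of two "chosen" sequences and their valid
# (non-masked) token counts; the numbers are made up for illustration.
sum_logps = torch.tensor([-12.0, -20.0])
valid_length = torch.tensor([4.0, 10.0])

# What ipo/orpo/simpo log-probs already hold after normalization:
avg_logps = sum_logps / valid_length       # tensor([-3., -2.])

# Pre-patch behavior divided again, shrinking the SFT loss term by
# another factor of the sequence length:
double_avg = avg_logps / valid_length      # tensor([-0.7500, -0.2000])

print(avg_logps, double_avg)

The redundant division made the auxiliary SFT loss vanishingly small for long sequences under those three loss types, which is why the patch returns chosen_logps untouched in that branch.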