From aa7c07caf090e4bd665a3eb081b7b0794f652b98 Mon Sep 17 00:00:00 2001
From: yinpu <741456392@qq.com>
Date: Tue, 21 Jan 2025 13:38:02 +0800
Subject: [PATCH] fix: avoid redundant normalization in DPO's SFT loss
 calculation (#6722)

Former-commit-id: 0f45982bac6b65533a94054ea5f792cb0f9e5a1f
---
 src/llamafactory/train/dpo/trainer.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/src/llamafactory/train/dpo/trainer.py b/src/llamafactory/train/dpo/trainer.py
index 770b32e5..1a1d5973 100644
--- a/src/llamafactory/train/dpo/trainer.py
+++ b/src/llamafactory/train/dpo/trainer.py
@@ -204,7 +204,11 @@ class CustomDPOTrainer(DPOTrainer):
         chosen_logps, rejected_logps = all_logps.split(batch_size, dim=0)
         chosen_logits, rejected_logits = all_logits.split(batch_size, dim=0)
         chosen_length, _ = valid_length.split(batch_size, dim=0)
-        return chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_logps / chosen_length
+
+        if self.loss_type in ["ipo", "orpo", "simpo"]:
+            return chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_logps
+        else:
+            return chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_logps / chosen_length
 
     @override
     def compute_reference_log_probs(
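
Note: the sketch below illustrates the double-normalization this patch removes. It assumes, as the commit title suggests, that earlier in concatenated_forward the trainer already divides all_logps by valid_length when loss_type is "ipo", "orpo", or "simpo" (these losses operate on length-averaged log-probs); the tensor values and this standalone script are hypothetical, not code from the repository.

import torch

batch_size = 2
loss_type = "simpo"  # vs. e.g. "sigmoid"

# Toy per-sequence summed log-probs (chosen x2, then rejected x2) and
# the number of valid (non-masked) label tokens in each sequence.
sum_logps = torch.tensor([-40.0, -25.0, -38.0, -30.0])
valid_length = torch.tensor([20.0, 10.0, 19.0, 15.0])

all_logps = sum_logps.clone()
if loss_type in ["ipo", "orpo", "simpo"]:
    # Already normalized here, before the chosen/rejected split.
    all_logps = all_logps / valid_length

chosen_logps, _rejected_logps = all_logps.split(batch_size, dim=0)
chosen_length, _ = valid_length.split(batch_size, dim=0)

# Old behavior: divide unconditionally, so ipo/orpo/simpo got normalized twice.
old_sft_logps = chosen_logps / chosen_length
# Patched behavior: divide only when all_logps is still a per-sequence sum.
if loss_type in ["ipo", "orpo", "simpo"]:
    new_sft_logps = chosen_logps
else:
    new_sft_logps = chosen_logps / chosen_length

print("old:", old_sft_logps)  # tensor([-0.1000, -0.2500]) -- divided by length twice
print("new:", new_sft_logps)  # tensor([-2.0000, -2.5000]) -- average log-prob per token

Under this assumption, the old code shrank the SFT (ftx) log-probs by an extra factor of the sequence length for these loss types, weakening the auxiliary SFT loss; the patch returns the already-averaged chosen_logps unchanged in that branch.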