From 56f45e826f828e44fcdca6a1a5a854d4b71f6ec7 Mon Sep 17 00:00:00 2001
From: Kingsley
Date: Tue, 4 Nov 2025 21:10:41 +0800
Subject: [PATCH] [train] fix MPO re-weight (#9405)

---
 src/llamafactory/train/dpo/trainer.py | 5 +----
 1 file changed, 1 insertion(+), 4 deletions(-)

diff --git a/src/llamafactory/train/dpo/trainer.py b/src/llamafactory/train/dpo/trainer.py
index da593d8d..c0ebc301 100644
--- a/src/llamafactory/train/dpo/trainer.py
+++ b/src/llamafactory/train/dpo/trainer.py
@@ -203,7 +203,7 @@ class CustomDPOTrainer(DPOTrainer):
             bco_losses = self.bco_loss(
                 policy_chosen_logps, policy_rejected_logps, reference_chosen_logps, reference_rejected_logps
             )
-            losses += bco_losses * self.bco_gemma
+            losses = (losses + bco_losses * self.bco_gemma) / (1.0 + self.bco_gemma)  # re-weight W_p and W_q
 
         return losses, chosen_rewards, rejected_rewards
 
@@ -284,9 +284,6 @@ class CustomDPOTrainer(DPOTrainer):
         sft_loss = -policy_chosen_logps_avg
         if self.ftx_gamma > 1e-6:
             losses += self.ftx_gamma * sft_loss
-        if self.bco_gemma > 1e-6:
-            # re-weigthing for MPO
-            losses /= self.ftx_gamma + self.bco_gemma + 1.0
 
         prefix = "eval_" if train_eval == "eval" else ""
         metrics[f"{prefix}rewards/chosen"] = chosen_rewards.mean().item()
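
Below is a minimal sketch of what the re-weighting change does numerically, assuming the SFT path is active (ftx_gamma > 0). The scalar values and the names dpo_loss, bco_loss, and sft_loss are illustrative stand-ins, not the trainer's actual tensors; only the two weighting formulas are taken from the diff above.

# Hypothetical weights and per-example losses for illustration only.
bco_gemma = 0.5   # weight of the BCO term
ftx_gamma = 0.1   # weight of the SFT term
dpo_loss, bco_loss, sft_loss = 0.60, 0.40, 1.20

# Before this patch: the BCO term was added unnormalized in compute_preference_loss,
# and the whole sum (including the SFT term) was later divided by
# (ftx_gamma + bco_gemma + 1.0) in get_batch_loss_metrics.
old_losses = (dpo_loss + bco_loss * bco_gemma + ftx_gamma * sft_loss) / (ftx_gamma + bco_gemma + 1.0)

# After this patch: only the preference terms are averaged, with
# W_p = 1 / (1 + bco_gemma) and W_q = bco_gemma / (1 + bco_gemma),
# while the SFT term keeps its full ftx_gamma weight.
new_losses = (dpo_loss + bco_loss * bco_gemma) / (1.0 + bco_gemma) + ftx_gamma * sft_loss

print(f"old: {old_losses:.4f}  new: {new_losses:.4f}")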