diff --git a/src/llamafactory/train/dpo/trainer.py b/src/llamafactory/train/dpo/trainer.py
index da593d8d..c0ebc301 100644
--- a/src/llamafactory/train/dpo/trainer.py
+++ b/src/llamafactory/train/dpo/trainer.py
@@ -203,7 +203,7 @@ class CustomDPOTrainer(DPOTrainer):
             bco_losses = self.bco_loss(
                 policy_chosen_logps, policy_rejected_logps, reference_chosen_logps, reference_rejected_logps
             )
-            losses += bco_losses * self.bco_gemma
+            losses = (losses + bco_losses * self.bco_gemma) / (1.0 + self.bco_gemma)  # re-weight W_p and W_q

         return losses, chosen_rewards, rejected_rewards

@@ -284,9 +284,6 @@ class CustomDPOTrainer(DPOTrainer):
         sft_loss = -policy_chosen_logps_avg
         if self.ftx_gamma > 1e-6:
             losses += self.ftx_gamma * sft_loss
-        if self.bco_gemma > 1e-6:
-            # re-weigthing for MPO
-            losses /= self.ftx_gamma + self.bco_gemma + 1.0

         prefix = "eval_" if train_eval == "eval" else ""
         metrics[f"{prefix}rewards/chosen"] = chosen_rewards.mean().item()
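For reference, below is a minimal standalone sketch of the re-weighting this patch implements: the DPO and BCO terms are normalized together with weights 1 / (1 + gamma) and gamma / (1 + gamma) inside the preference-loss step, and the SFT term is added afterwards without the old global division by `ftx_gamma + bco_gemma + 1.0`. The helper `combine_mpo_losses` and its signature are hypothetical, for illustration only, and not part of `CustomDPOTrainer`.

```python
import torch


def combine_mpo_losses(
    dpo_losses: torch.Tensor,
    bco_losses: torch.Tensor,
    sft_loss: torch.Tensor,
    bco_gemma: float,
    ftx_gamma: float,
) -> torch.Tensor:
    """Illustrative sketch of the post-patch loss combination (hypothetical helper)."""
    losses = dpo_losses
    if bco_gemma > 1e-6:
        # re-weight so DPO gets 1 / (1 + gamma) and BCO gets gamma / (1 + gamma)
        losses = (losses + bco_losses * bco_gemma) / (1.0 + bco_gemma)
    if ftx_gamma > 1e-6:
        # SFT term is added on top and no longer rescaled by a global denominator
        losses = losses + ftx_gamma * sft_loss
    return losses
```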