[train] fix MPO re-weight (#9405)

This commit is contained in:
Kingsley 2025-11-04 21:10:41 +08:00 committed by GitHub
parent 14abb75126
commit 56f45e826f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -203,7 +203,7 @@ class CustomDPOTrainer(DPOTrainer):
bco_losses = self.bco_loss( bco_losses = self.bco_loss(
policy_chosen_logps, policy_rejected_logps, reference_chosen_logps, reference_rejected_logps policy_chosen_logps, policy_rejected_logps, reference_chosen_logps, reference_rejected_logps
) )
losses += bco_losses * self.bco_gemma losses = (losses + bco_losses * self.bco_gemma) / (1.0 + self.bco_gemma) # re-weight W_p and W_q
return losses, chosen_rewards, rejected_rewards return losses, chosen_rewards, rejected_rewards
@@ -284,9 +284,6 @@ class CustomDPOTrainer(DPOTrainer):
sft_loss = -policy_chosen_logps_avg sft_loss = -policy_chosen_logps_avg
if self.ftx_gamma > 1e-6: if self.ftx_gamma > 1e-6:
losses += self.ftx_gamma * sft_loss losses += self.ftx_gamma * sft_loss
if self.bco_gemma > 1e-6:
# re-weighting for MPO
losses /= self.ftx_gamma + self.bco_gemma + 1.0
prefix = "eval_" if train_eval == "eval" else "" prefix = "eval_" if train_eval == "eval" else ""
metrics[f"{prefix}rewards/chosen"] = chosen_rewards.mean().item() metrics[f"{prefix}rewards/chosen"] = chosen_rewards.mean().item()