mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-11-06 02:42:15 +08:00
[train] fix MPO re-weight (#9405)
This commit is contained in:
parent
14abb75126
commit
56f45e826f
@ -203,7 +203,7 @@ class CustomDPOTrainer(DPOTrainer):
|
|||||||
bco_losses = self.bco_loss(
|
bco_losses = self.bco_loss(
|
||||||
policy_chosen_logps, policy_rejected_logps, reference_chosen_logps, reference_rejected_logps
|
policy_chosen_logps, policy_rejected_logps, reference_chosen_logps, reference_rejected_logps
|
||||||
)
|
)
|
||||||
losses += bco_losses * self.bco_gemma
|
losses = (losses + bco_losses * self.bco_gemma) / (1.0 + self.bco_gemma) # re-weight W_p and W_q
|
||||||
|
|
||||||
return losses, chosen_rewards, rejected_rewards
|
return losses, chosen_rewards, rejected_rewards
|
||||||
|
|
||||||
@ -284,9 +284,6 @@ class CustomDPOTrainer(DPOTrainer):
|
|||||||
sft_loss = -policy_chosen_logps_avg
|
sft_loss = -policy_chosen_logps_avg
|
||||||
if self.ftx_gamma > 1e-6:
|
if self.ftx_gamma > 1e-6:
|
||||||
losses += self.ftx_gamma * sft_loss
|
losses += self.ftx_gamma * sft_loss
|
||||||
if self.bco_gemma > 1e-6:
|
|
||||||
# re-weigthing for MPO
|
|
||||||
losses /= self.ftx_gamma + self.bco_gemma + 1.0
|
|
||||||
|
|
||||||
prefix = "eval_" if train_eval == "eval" else ""
|
prefix = "eval_" if train_eval == "eval" else ""
|
||||||
metrics[f"{prefix}rewards/chosen"] = chosen_rewards.mean().item()
|
metrics[f"{prefix}rewards/chosen"] = chosen_rewards.mean().item()
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user