mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-09-12 16:12:48 +08:00
Merge pull request #6052 from hiyouga/hiyouga-patch-1
[trainer] fix DPO metrics

Former-commit-id: 45f32916ce3e0f1d242b91bbf9dbce2c0200f82d
This commit is contained in:
commit
acd70faf17
@@ -255,10 +255,10 @@ class CustomDPOTrainer(DPOTrainer):
|
|||||||
metrics[f"{prefix}rewards/rejected"] = rejected_rewards.mean().item()
|
metrics[f"{prefix}rewards/rejected"] = rejected_rewards.mean().item()
|
||||||
metrics[f"{prefix}rewards/accuracies"] = (chosen_rewards > rejected_rewards).float().mean().item()
|
metrics[f"{prefix}rewards/accuracies"] = (chosen_rewards > rejected_rewards).float().mean().item()
|
||||||
metrics[f"{prefix}rewards/margins"] = (chosen_rewards - rejected_rewards).mean().item()
|
metrics[f"{prefix}rewards/margins"] = (chosen_rewards - rejected_rewards).mean().item()
|
||||||
metrics[f"{prefix}logps/rejected"] = policy_chosen_logps.mean().item()
|
metrics[f"{prefix}logps/chosen"] = policy_chosen_logps.mean().item()
|
||||||
metrics[f"{prefix}logps/chosen"] = policy_rejected_logps.mean().item()
|
metrics[f"{prefix}logps/rejected"] = policy_rejected_logps.mean().item()
|
||||||
metrics[f"{prefix}logits/rejected"] = policy_chosen_logits.mean().item()
|
metrics[f"{prefix}logits/chosen"] = policy_chosen_logits.mean().item()
|
||||||
metrics[f"{prefix}logits/chosen"] = policy_rejected_logits.mean().item()
|
metrics[f"{prefix}logits/rejected"] = policy_rejected_logits.mean().item()
|
||||||
if self.loss_type == "orpo":
|
if self.loss_type == "orpo":
|
||||||
metrics[f"{prefix}sft_loss"] = sft_loss.mean().item()
|
metrics[f"{prefix}sft_loss"] = sft_loss.mean().item()
|
||||||
metrics[f"{prefix}odds_ratio_loss"] = ((losses - sft_loss) / self.beta).mean().item()
|
metrics[f"{prefix}odds_ratio_loss"] = ((losses - sft_loss) / self.beta).mean().item()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user