From 9815d1712c1186eb6917d829192a96ccb6242ab2 Mon Sep 17 00:00:00 2001 From: hoshi-hiyouga Date: Sat, 16 Nov 2024 16:11:16 +0800 Subject: [PATCH] fix #6050 Former-commit-id: dc828218726704ff0453a2d13535663ac6ad7833 --- src/llamafactory/train/dpo/trainer.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/llamafactory/train/dpo/trainer.py b/src/llamafactory/train/dpo/trainer.py index fdc41dd7..7e76dee2 100644 --- a/src/llamafactory/train/dpo/trainer.py +++ b/src/llamafactory/train/dpo/trainer.py @@ -255,10 +255,10 @@ class CustomDPOTrainer(DPOTrainer): metrics[f"{prefix}rewards/rejected"] = rejected_rewards.mean().item() metrics[f"{prefix}rewards/accuracies"] = (chosen_rewards > rejected_rewards).float().mean().item() metrics[f"{prefix}rewards/margins"] = (chosen_rewards - rejected_rewards).mean().item() - metrics[f"{prefix}logps/rejected"] = policy_chosen_logps.mean().item() - metrics[f"{prefix}logps/chosen"] = policy_rejected_logps.mean().item() - metrics[f"{prefix}logits/rejected"] = policy_chosen_logits.mean().item() - metrics[f"{prefix}logits/chosen"] = policy_rejected_logits.mean().item() + metrics[f"{prefix}logps/chosen"] = policy_chosen_logps.mean().item() + metrics[f"{prefix}logps/rejected"] = policy_rejected_logps.mean().item() + metrics[f"{prefix}logits/chosen"] = policy_chosen_logits.mean().item() + metrics[f"{prefix}logits/rejected"] = policy_rejected_logits.mean().item() if self.loss_type == "orpo": metrics[f"{prefix}sft_loss"] = sft_loss.mean().item() metrics[f"{prefix}odds_ratio_loss"] = ((losses - sft_loss) / self.beta).mean().item()