From 089e4d9e96d00bbab12bcc769c990f5e59d8e95e Mon Sep 17 00:00:00 2001 From: hoshi-hiyouga Date: Sat, 16 Nov 2024 16:11:16 +0800 Subject: [PATCH] fix #6050 Former-commit-id: 028ea3d9b4fa4ab74a969ac80e61a449d6c15e74 --- src/llamafactory/train/dpo/trainer.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/llamafactory/train/dpo/trainer.py b/src/llamafactory/train/dpo/trainer.py index fdc41dd7..7e76dee2 100644 --- a/src/llamafactory/train/dpo/trainer.py +++ b/src/llamafactory/train/dpo/trainer.py @@ -255,10 +255,10 @@ class CustomDPOTrainer(DPOTrainer): metrics[f"{prefix}rewards/rejected"] = rejected_rewards.mean().item() metrics[f"{prefix}rewards/accuracies"] = (chosen_rewards > rejected_rewards).float().mean().item() metrics[f"{prefix}rewards/margins"] = (chosen_rewards - rejected_rewards).mean().item() - metrics[f"{prefix}logps/rejected"] = policy_chosen_logps.mean().item() - metrics[f"{prefix}logps/chosen"] = policy_rejected_logps.mean().item() - metrics[f"{prefix}logits/rejected"] = policy_chosen_logits.mean().item() - metrics[f"{prefix}logits/chosen"] = policy_rejected_logits.mean().item() + metrics[f"{prefix}logps/chosen"] = policy_chosen_logps.mean().item() + metrics[f"{prefix}logps/rejected"] = policy_rejected_logps.mean().item() + metrics[f"{prefix}logits/chosen"] = policy_chosen_logits.mean().item() + metrics[f"{prefix}logits/rejected"] = policy_rejected_logits.mean().item() if self.loss_type == "orpo": metrics[f"{prefix}sft_loss"] = sft_loss.mean().item() metrics[f"{prefix}odds_ratio_loss"] = ((losses - sft_loss) / self.beta).mean().item()