From bd52e2b404ea3cc6a85eb78e346708c6fd670876 Mon Sep 17 00:00:00 2001
From: hiyouga
Date: Mon, 1 Apr 2024 14:42:41 +0800
Subject: [PATCH] fix ORPO loss

Former-commit-id: 816d71414617590f95de89a49f38358e597ed121
---
 src/llmtuner/train/orpo/trainer.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/llmtuner/train/orpo/trainer.py b/src/llmtuner/train/orpo/trainer.py
index 50b999f8..f5b7ff42 100644
--- a/src/llmtuner/train/orpo/trainer.py
+++ b/src/llmtuner/train/orpo/trainer.py
@@ -99,7 +99,7 @@ class CustomORPOTrainer(DPOTrainer):
         """
         metrics = {}
         chosen_logps, rejected_logps, chosen_logits, rejected_logits = self.concatenated_forward(model, batch)
-        sft_loss = chosen_logps
+        sft_loss = -chosen_logps
         odds_ratio_loss = self.odds_ratio_loss(chosen_logps, rejected_logps)
         batch_loss = (sft_loss + self.beta * odds_ratio_loss).mean()
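
The one-character fix matters because the SFT component of ORPO is a negative log-likelihood: `chosen_logps` is the (average) log-probability of the chosen response, so the loss to minimize is `-chosen_logps`; without the sign flip the trainer would have rewarded the model for assigning *low* probability to the chosen answers. Below is a minimal sketch of how the corrected objective is assembled. The function name `orpo_loss`, the standalone signature, and the body of the odds-ratio term (taken from the ORPO paper's log-sigmoid of the log-odds) are assumptions for illustration; the trainer's actual `odds_ratio_loss` is defined elsewhere in `trainer.py` and is not shown in this diff.

```python
# Hypothetical sketch of the corrected ORPO loss, assuming chosen_logps and
# rejected_logps are average per-token log-probabilities in (-inf, 0).
import torch
import torch.nn.functional as F


def orpo_loss(chosen_logps: torch.Tensor, rejected_logps: torch.Tensor, beta: float) -> torch.Tensor:
    # SFT term: negative log-likelihood of the chosen responses.
    # This sign is exactly what the patch fixes.
    sft_loss = -chosen_logps

    # Odds-ratio term (assumed, following the ORPO paper): negative log-sigmoid
    # of the log-odds of preferring the chosen response over the rejected one.
    log_odds = (chosen_logps - rejected_logps) - (
        torch.log1p(-torch.exp(chosen_logps)) - torch.log1p(-torch.exp(rejected_logps))
    )
    odds_ratio_loss = -F.logsigmoid(log_odds)

    # Total objective, averaged over the batch, mirroring the patched line:
    # batch_loss = (sft_loss + self.beta * odds_ratio_loss).mean()
    return (sft_loss + beta * odds_ratio_loss).mean()
```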