fix ORPO loss

Former-commit-id: 816d71414617590f95de89a49f38358e597ed121
hiyouga 2024-04-01 14:42:41 +08:00
parent 69e1d39832
commit bd52e2b404

@@ -99,7 +99,7 @@ class CustomORPOTrainer(DPOTrainer):
         """
         metrics = {}
         chosen_logps, rejected_logps, chosen_logits, rejected_logits = self.concatenated_forward(model, batch)
-        sft_loss = chosen_logps
+        sft_loss = -chosen_logps
         odds_ratio_loss = self.odds_ratio_loss(chosen_logps, rejected_logps)
         batch_loss = (sft_loss + self.beta * odds_ratio_loss).mean()
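
For context, the fix flips the sign of the SFT term: chosen_logps is a log-probability (a non-positive value), so using it directly as a loss would push the model away from the chosen responses; the negation turns it into the intended negative log-likelihood. Below is a minimal standalone sketch of the ORPO batch loss under the assumption that chosen_logps and rejected_logps are length-averaged log-probabilities of the chosen and rejected responses and beta is the odds-ratio weight; the function and variable names are illustrative, not the trainer's actual API.

import torch
import torch.nn.functional as F

def orpo_batch_loss(chosen_logps: torch.Tensor, rejected_logps: torch.Tensor, beta: float = 0.1) -> torch.Tensor:
    # SFT term: negative log-likelihood of the chosen response
    # (this is the sign corrected by this commit).
    sft_loss = -chosen_logps
    # Odds-ratio term: log odds(p) = log p - log(1 - p), with p = exp(logps),
    # so log1p(-exp(logps)) computes log(1 - p) in a numerically stable way.
    log_odds = (chosen_logps - rejected_logps) - (
        torch.log1p(-torch.exp(chosen_logps)) - torch.log1p(-torch.exp(rejected_logps))
    )
    odds_ratio_loss = -F.logsigmoid(log_odds)
    return (sft_loss + beta * odds_ratio_loss).mean()

# Example with two preference pairs (hypothetical values).
chosen = torch.tensor([-0.8, -1.2])
rejected = torch.tensor([-1.5, -1.4])
print(orpo_batch_loss(chosen, rejected))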