Fix ORPO loss: the SFT term must be the negative log-likelihood of the chosen response (`-chosen_logps`), not the raw log-probability.

This commit is contained in:
hiyouga
2024-04-01 14:42:41 +08:00
parent 5b9b40403d
commit 816d714146

View File

@@ -99,7 +99,7 @@ class CustomORPOTrainer(DPOTrainer):
"""
metrics = {}
chosen_logps, rejected_logps, chosen_logits, rejected_logits = self.concatenated_forward(model, batch)
sft_loss = chosen_logps
sft_loss = -chosen_logps
odds_ratio_loss = self.odds_ratio_loss(chosen_logps, rejected_logps)
batch_loss = (sft_loss + self.beta * odds_ratio_loss).mean()