fix ORPO loss

Former-commit-id: 5544ddde9087f00f9e20b78d0079f20c2f5d1604
This commit is contained in:
hiyouga 2024-04-01 14:42:41 +08:00
parent 52d402e2a9
commit be0a807e8c

View File

@ -99,7 +99,7 @@ class CustomORPOTrainer(DPOTrainer):
"""
metrics = {}
chosen_logps, rejected_logps, chosen_logits, rejected_logits = self.concatenated_forward(model, batch)
sft_loss = chosen_logps
sft_loss = -chosen_logps
odds_ratio_loss = self.odds_ratio_loss(chosen_logps, rejected_logps)
batch_loss = (sft_loss + self.beta * odds_ratio_loss).mean()