mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-03 20:22:49 +08:00
fix ORPO loss
Former-commit-id: 816d71414617590f95de89a49f38358e597ed121
This commit is contained in:
parent
69e1d39832
commit
bd52e2b404
@ -99,7 +99,7 @@ class CustomORPOTrainer(DPOTrainer):
|
||||
"""
|
||||
metrics = {}
|
||||
chosen_logps, rejected_logps, chosen_logits, rejected_logits = self.concatenated_forward(model, batch)
|
||||
sft_loss = chosen_logps
|
||||
sft_loss = -chosen_logps
|
||||
odds_ratio_loss = self.odds_ratio_loss(chosen_logps, rejected_logps)
|
||||
batch_loss = (sft_loss + self.beta * odds_ratio_loss).mean()
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user