mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-10-14 23:58:11 +08:00
fix ORPO loss
Former-commit-id: 5544ddde9087f00f9e20b78d0079f20c2f5d1604
This commit is contained in:
parent
52d402e2a9
commit
be0a807e8c
@ -99,7 +99,7 @@ class CustomORPOTrainer(DPOTrainer):
|
||||
"""
|
||||
metrics = {}
|
||||
chosen_logps, rejected_logps, chosen_logits, rejected_logits = self.concatenated_forward(model, batch)
|
||||
sft_loss = chosen_logps
|
||||
sft_loss = -chosen_logps
|
||||
odds_ratio_loss = self.odds_ratio_loss(chosen_logps, rejected_logps)
|
||||
batch_loss = (sft_loss + self.beta * odds_ratio_loss).mean()
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user