mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-05 21:22:50 +08:00
fix ORPO loss
Former-commit-id: 816d71414617590f95de89a49f38358e597ed121
This commit is contained in:
parent
69e1d39832
commit
bd52e2b404
@ -99,7 +99,7 @@ class CustomORPOTrainer(DPOTrainer):
|
|||||||
"""
|
"""
|
||||||
metrics = {}
|
metrics = {}
|
||||||
chosen_logps, rejected_logps, chosen_logits, rejected_logits = self.concatenated_forward(model, batch)
|
chosen_logps, rejected_logps, chosen_logits, rejected_logits = self.concatenated_forward(model, batch)
|
||||||
sft_loss = chosen_logps
|
sft_loss = -chosen_logps
|
||||||
odds_ratio_loss = self.odds_ratio_loss(chosen_logps, rejected_logps)
|
odds_ratio_loss = self.odds_ratio_loss(chosen_logps, rejected_logps)
|
||||||
batch_loss = (sft_loss + self.beta * odds_ratio_loss).mean()
|
batch_loss = (sft_loss + self.beta * odds_ratio_loss).mean()
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user