mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-10-14 23:58:11 +08:00
fix: avoid redundant normalization in DPO's SFT loss calculation (#6722)
Former-commit-id: 971a8ccbdacf130763d40c7ef82a711b2fc1292f
This commit is contained in:
parent
db9b977e4f
commit
a8fae3869d
@ -204,6 +204,10 @@ class CustomDPOTrainer(DPOTrainer):
|
||||
chosen_logps, rejected_logps = all_logps.split(batch_size, dim=0)
|
||||
chosen_logits, rejected_logits = all_logits.split(batch_size, dim=0)
|
||||
chosen_length, _ = valid_length.split(batch_size, dim=0)
|
||||
|
||||
if self.loss_type in ["ipo", "orpo", "simpo"]:
|
||||
return chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_logps
|
||||
else:
|
||||
return chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_logps / chosen_length
|
||||
|
||||
@override
|
||||
|
Loading…
x
Reference in New Issue
Block a user