fix: avoid redundant normalization in DPO's SFT loss calculation (#6722)

Former-commit-id: 971a8ccbdacf130763d40c7ef82a711b2fc1292f
This commit is contained in:
yinpu 2025-01-21 13:38:02 +08:00 committed by GitHub
parent f34390b596
commit 5062b099f7

View File

@ -204,6 +204,10 @@ class CustomDPOTrainer(DPOTrainer):
chosen_logps, rejected_logps = all_logps.split(batch_size, dim=0)
chosen_logits, rejected_logits = all_logits.split(batch_size, dim=0)
chosen_length, _ = valid_length.split(batch_size, dim=0)
if self.loss_type in ["ipo", "orpo", "simpo"]:
return chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_logps
else:
return chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_logps / chosen_length
@override