fix: avoid redundant normalization in DPO's SFT loss calculation (#6722)

Former-commit-id: 0f45982bac6b65533a94054ea5f792cb0f9e5a1f
Author: yinpu, 2025-01-21 13:38:02 +08:00, committed by GitHub
parent 324f07613a
commit aa7c07caf0


@@ -204,6 +204,10 @@ class CustomDPOTrainer(DPOTrainer):
        chosen_logps, rejected_logps = all_logps.split(batch_size, dim=0)
        chosen_logits, rejected_logits = all_logits.split(batch_size, dim=0)
        chosen_length, _ = valid_length.split(batch_size, dim=0)
        if self.loss_type in ["ipo", "orpo", "simpo"]:
            return chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_logps
        else:
            return chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_logps / chosen_length
@override
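
For context, a minimal standalone sketch of the idea behind this change. The function name split_and_average, the example shapes, and the early all_logps / valid_length division are hypothetical stand-ins (assumed to mirror what happens earlier in concatenated_forward for the length-normalized loss types), not the project's actual helpers; the point is only that the per-length division happens exactly once.

    import torch

    def split_and_average(all_logps, all_logits, valid_length, batch_size, loss_type):
        # Assumption: for IPO/ORPO/SimPO the summed log-probs are divided by
        # valid_length earlier in the forward pass (per-token average).
        if loss_type in ("ipo", "orpo", "simpo"):
            all_logps = all_logps / valid_length

        chosen_logps, rejected_logps = all_logps.split(batch_size, dim=0)
        chosen_logits, rejected_logits = all_logits.split(batch_size, dim=0)
        chosen_length, _ = valid_length.split(batch_size, dim=0)

        if loss_type in ("ipo", "orpo", "simpo"):
            # Already per-token averaged above; dividing by chosen_length again
            # would normalize twice, which is the redundancy this commit removes.
            chosen_avg_logps = chosen_logps
        else:
            # Other loss types keep summed log-probs, so the SFT term is
            # length-averaged exactly once, here.
            chosen_avg_logps = chosen_logps / chosen_length

        return chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_avg_logps

    # Hypothetical usage: 2 chosen + 2 rejected sequences concatenated along dim 0.
    all_logps = torch.randn(4)
    all_logits = torch.randn(4, 16, 128)
    valid_length = torch.randint(1, 100, (4,)).float()
    outs = split_and_average(all_logps, all_logits, valid_length, batch_size=2, loss_type="simpo")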