mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-23 14:22:51 +08:00
fix: avoid redundant normalization in DPO's SFT loss calculation (#6722)
Former-commit-id: 0f45982bac6b65533a94054ea5f792cb0f9e5a1f
This commit is contained in:
parent
324f07613a
commit
aa7c07caf0
@ -204,7 +204,11 @@ class CustomDPOTrainer(DPOTrainer):
|
||||
chosen_logps, rejected_logps = all_logps.split(batch_size, dim=0)
|
||||
chosen_logits, rejected_logits = all_logits.split(batch_size, dim=0)
|
||||
chosen_length, _ = valid_length.split(batch_size, dim=0)
|
||||
return chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_logps / chosen_length
|
||||
|
||||
if self.loss_type in ["ipo", "orpo", "simpo"]:
|
||||
return chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_logps
|
||||
else:
|
||||
return chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_logps / chosen_length
|
||||
|
||||
@override
|
||||
def compute_reference_log_probs(
|
||||
|
Loading…
x
Reference in New Issue
Block a user