mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-10-14 23:58:11 +08:00
fix: avoid redundant normalization in DPO's SFT loss calculation (#6722)
Former-commit-id: 971a8ccbdacf130763d40c7ef82a711b2fc1292f
This commit is contained in:
parent
db9b977e4f
commit
a8fae3869d
@ -204,7 +204,11 @@ class CustomDPOTrainer(DPOTrainer):
|
|||||||
chosen_logps, rejected_logps = all_logps.split(batch_size, dim=0)
|
chosen_logps, rejected_logps = all_logps.split(batch_size, dim=0)
|
||||||
chosen_logits, rejected_logits = all_logits.split(batch_size, dim=0)
|
chosen_logits, rejected_logits = all_logits.split(batch_size, dim=0)
|
||||||
chosen_length, _ = valid_length.split(batch_size, dim=0)
|
chosen_length, _ = valid_length.split(batch_size, dim=0)
|
||||||
return chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_logps / chosen_length
|
|
||||||
|
if self.loss_type in ["ipo", "orpo", "simpo"]:
|
||||||
|
return chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_logps
|
||||||
|
else:
|
||||||
|
return chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_logps / chosen_length
|
||||||
|
|
||||||
@override
|
@override
|
||||||
def compute_reference_log_probs(
|
def compute_reference_log_probs(
|
||||||
|
Loading…
x
Reference in New Issue
Block a user