Mirror of https://github.com/hiyouga/LLaMA-Factory.git, synced 2025-12-16 20:00:36 +08:00
fix dpo metrics
@@ -87,7 +87,11 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
         """
         loss = super().compute_loss(model, inputs, return_outputs, **kwargs)
         if is_transformers_version_equal_to_4_46() and not getattr(self, "model_accepts_loss_kwargs", False):
-            loss /= self.args.gradient_accumulation_steps  # other model should not scale the loss
+            # other model should not scale the loss
+            if return_outputs:
+                return (loss[0] / self.args.gradient_accumulation_steps, *loss[1:])
+            else:
+                return loss / self.args.gradient_accumulation_steps
 
         return loss
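
For readability, here is the patched method assembled from the hunk above as a minimal, self-contained sketch. The method signature, the comments, and the version-check helper are filled in from context (the helper below is a stand-in for the repository's is_transformers_version_equal_to_4_46); only the lines shown in the diff are taken verbatim. The point of the change: with return_outputs=True the parent compute_loss returns a (loss, outputs) tuple, so the division by gradient_accumulation_steps is now applied to the loss element only, instead of the old in-place `loss /= ...` on the whole return value.

import transformers
from transformers import Seq2SeqTrainer


def is_transformers_version_equal_to_4_46() -> bool:
    # Stand-in for the repository helper of the same name; it only needs to
    # flag the 4.46.x releases that this workaround targets.
    return transformers.__version__.startswith("4.46")


class CustomSeq2SeqTrainer(Seq2SeqTrainer):
    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        # Parent call: returns a tensor, or a (loss, outputs) tuple when
        # return_outputs=True.
        loss = super().compute_loss(model, inputs, return_outputs, **kwargs)
        if is_transformers_version_equal_to_4_46() and not getattr(self, "model_accepts_loss_kwargs", False):
            # other model should not scale the loss
            if return_outputs:
                # Scale only the loss element, leave the outputs untouched.
                return (loss[0] / self.args.gradient_accumulation_steps, *loss[1:])
            else:
                return loss / self.args.gradient_accumulation_steps

        return loss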