Fix DPO metrics

This commit is contained in:
hiyouga
2024-11-02 19:22:11 +08:00
parent 07e5088851
commit 4270f7dfb9
7 changed files with 143 additions and 58 deletions

View File

@@ -87,7 +87,11 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
"""
loss = super().compute_loss(model, inputs, return_outputs, **kwargs)
if is_transformers_version_equal_to_4_46() and not getattr(self, "model_accepts_loss_kwargs", False):
loss /= self.args.gradient_accumulation_steps # other model should not scale the loss
# other model should not scale the loss
if return_outputs:
return (loss[0] / self.args.gradient_accumulation_steps, *loss[1:])
else:
return loss / self.args.gradient_accumulation_steps
return loss