[train] fix denominator of ga in ksft loss (#9409)

This commit is contained in:
Peilin Li 2025-11-05 20:53:23 +08:00 committed by GitHub
parent 8edd2622ce
commit bd30c0003b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -83,6 +83,7 @@ def run_sft(
**dataset_module,
**metric_module,
)
trainer.model_accepts_loss_kwargs = False
# Training
if training_args.do_train: