This commit is contained in:
fzc8578
2025-01-10 20:55:52 +08:00
parent d09032049c
commit 15bbcdf8d3
4 changed files with 10 additions and 9 deletions

View File

@@ -122,6 +122,7 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
labels = inputs.pop("labels", None)
else:
labels = inputs.get("labels")
loss, generated_tokens, _ = super().prediction_step(
model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys, **gen_kwargs
)