This commit is contained in:
fzc8578
2025-01-06 19:32:39 +08:00
parent ab87bd6b13
commit 785cc70ff2
4 changed files with 15 additions and 7 deletions

View File

@@ -24,6 +24,7 @@ import numpy as np
import torch
from transformers import Seq2SeqTrainer
from typing_extensions import override
import copy
from ...extras import logging
from ...extras.constants import IGNORE_INDEX
@@ -122,7 +123,6 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
labels = inputs.pop("labels", None)
else:
labels = inputs.get("labels")
loss, generated_tokens, _ = super().prediction_step(
model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys, **gen_kwargs
)