mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-12-17 12:20:37 +08:00
fix some
This commit is contained in:
@@ -122,6 +122,7 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
|
||||
labels = inputs.pop("labels", None)
|
||||
else:
|
||||
labels = inputs.get("labels")
|
||||
|
||||
loss, generated_tokens, _ = super().prediction_step(
|
||||
model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys, **gen_kwargs
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user