Mirror of https://github.com/hiyouga/LLaMA-Factory.git, synced 2025-12-17 04:10:36 +08:00
Commit: add some
@@ -24,6 +24,7 @@ import numpy as np
 import torch
 from transformers import Seq2SeqTrainer
 from typing_extensions import override
+import copy

 from ...extras import logging
 from ...extras.constants import IGNORE_INDEX
@@ -122,7 +123,6 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
             labels = inputs.pop("labels", None)
         else:
             labels = inputs.get("labels")

         loss, generated_tokens, _ = super().prediction_step(
             model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys, **gen_kwargs
         )
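For context, the second hunk sits inside the prediction_step override of CustomSeq2SeqTrainer. Below is a minimal, hypothetical sketch of the pattern visible in those context lines: when predict_with_generate is enabled, labels are popped from the batch so they are never forwarded to generation; otherwise they are read in place, and the parent Seq2SeqTrainer.prediction_step is then called. The class name SketchSeq2SeqTrainer and the final return line are assumptions for illustration, and the use of the newly added `import copy` is not visible in this excerpt, so it is not shown.

# Minimal sketch, assuming a recent transformers release and typing_extensions installed.
# Only the labels handling and the super() call appear in the hunk above; the rest is assumed.
from typing import Any, Optional, Union

import torch
from transformers import Seq2SeqTrainer
from typing_extensions import override


class SketchSeq2SeqTrainer(Seq2SeqTrainer):
    @override
    def prediction_step(
        self,
        model: "torch.nn.Module",
        inputs: dict[str, Union["torch.Tensor", Any]],
        prediction_loss_only: bool,
        ignore_keys: Optional[list[str]] = None,
        **gen_kwargs,
    ) -> tuple[Optional[float], Optional["torch.Tensor"], Optional["torch.Tensor"]]:
        if self.args.predict_with_generate:
            # generation must not receive the labels, so drop them from the batch
            labels = inputs.pop("labels", None)
        else:
            # teacher-forced evaluation keeps the labels for the loss computation
            labels = inputs.get("labels")

        loss, generated_tokens, _ = super().prediction_step(
            model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys, **gen_kwargs
        )
        return loss, generated_tokens, labels  # assumed return shape: (loss, predictions, labels)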