Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-12-18 12:50:38 +08:00)
fix #6546
@@ -30,7 +30,7 @@ from typing_extensions import override
 from ...extras.constants import IGNORE_INDEX
 from ...extras.packages import is_transformers_version_equal_to_4_46, is_transformers_version_greater_than
 from ..callbacks import SaveProcessorCallback
-from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_batch_logps
+from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_batch_logps, nested_detach


 if TYPE_CHECKING:
@@ -142,7 +142,7 @@ class CustomKTOTrainer(KTOTrainer):
         r"""
         Runs forward pass and computes the log probabilities.
         """
-        batch = {k: v.detach().clone() for k, v in batch.items()}  # avoid error
+        batch = nested_detach(batch, clone=True)  # avoid error
         model_inputs = {
             "input_ids": batch[f"{prefix}input_ids"],
             "attention_mask": batch[f"{prefix}attention_mask"],
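
What changed: the old line detached and cloned each value of `batch` with a flat dict comprehension, which only works when every value is a tensor. The new line delegates to a `nested_detach(batch, clone=True)` helper imported from the project's `trainer_utils` module. The sketch below is only an illustration of the behavior implied by the call site, assuming a recursive walk over lists, tuples, and mappings with an optional clone; it is not the repository's actual implementation.

# Illustrative sketch only (assumption): a recursive detach/clone helper with the
# same call signature as the one used above. The real helper is defined in the
# repository's trainer_utils module and may differ.
from collections.abc import Mapping
from typing import Any

import torch


def nested_detach(tensors: Any, clone: bool = False) -> Any:
    """Detach (and optionally clone) tensors inside nested lists/tuples/dicts."""
    if isinstance(tensors, (list, tuple)):
        return type(tensors)(nested_detach(t, clone=clone) for t in tensors)
    if isinstance(tensors, Mapping):
        return type(tensors)({k: nested_detach(v, clone=clone) for k, v in tensors.items()})
    if isinstance(tensors, torch.Tensor):
        return tensors.detach().clone() if clone else tensors.detach()
    return tensors  # non-tensor values pass through unchanged

For example, a batch such as {"input_ids": tensor, "pixel_values": [tensor, tensor]} would break the per-key `v.detach().clone()` comprehension on the list value, while a recursive helper detaches the tensors inside it, which shows why a helper of this shape is more robust than the flat comprehension.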