This commit is contained in:
hiyouga
2024-12-30 05:55:15 +00:00
parent b55890291b
commit 6f5bb3b8e5
7 changed files with 26 additions and 11 deletions

View File

@@ -28,7 +28,7 @@ from trl.trainer import disable_dropout_in_model
from typing_extensions import override
from ...extras.constants import IGNORE_INDEX
from ...extras.packages import is_transformers_version_equal_to_4_46
from ...extras.packages import is_transformers_version_equal_to_4_46, is_transformers_version_greater_than
from ..callbacks import SaveProcessorCallback
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_batch_logps
@@ -50,6 +50,9 @@ class CustomKTOTrainer(KTOTrainer):
disable_dropout: bool = True,
**kwargs,
):
if is_transformers_version_greater_than("4.46"):
kwargs["processing_class"] = kwargs.pop("tokenizer")
if disable_dropout:
disable_dropout_in_model(model)
if ref_model is not None: