mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2026-06-22 23:28:57 +08:00
[core deps] upgrade TRL to be between 0.18 and 0.24 (#9617)
Co-authored-by: Yaowei Zheng <hiyouga@buaa.edu.cn>
This commit is contained in:
@@ -25,6 +25,7 @@ import torch
|
||||
from transformers import Trainer
|
||||
from trl import KTOTrainer
|
||||
from trl.trainer import disable_dropout_in_model
|
||||
+ from trl.trainer.utils import prepare_deepspeed
|
||||
from typing_extensions import override
|
||||
|
||||
from ...extras.constants import IGNORE_INDEX
|
||||
@@ -77,6 +78,13 @@ class CustomKTOTrainer(KTOTrainer):
|
||||
self.desirable_weight = finetuning_args.kto_chosen_weight
|
||||
self.undesirable_weight = finetuning_args.kto_rejected_weight
|
||||
self.ftx_gamma = finetuning_args.pref_ftx
|
||||
+ # trl
|
||||
+ # Not all losses require a KL calculation
|
||||
+ self.calculate_KL = True
|
||||
+ if hasattr(self, "loss_type") and self.loss_type in ["apo_zero_unpaired"]:
|
||||
+     self.calculate_KL = False
|
||||
+ else:
|
||||
+     self.loss_type = "kto"
|
||||
|
||||
Trainer.__init__(self, model=model, **kwargs)
|
||||
self.model_accepts_loss_kwargs = False # overwrite trainer's default behavior
|
||||
@@ -90,7 +98,7 @@ class CustomKTOTrainer(KTOTrainer):
|
||||
if not (
|
||||
getattr(ref_model, "is_loaded_in_8bit", False) or getattr(ref_model, "is_loaded_in_4bit", False)
|
||||
): # quantized models are already set on the correct device
|
||||
- self.ref_model = self._prepare_deepspeed(self.ref_model)
|
||||
+ self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator)
|
||||
else:
|
||||
self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
|
||||
self.ref_model.eval()
|
||||
|
||||
Reference in New Issue
Block a user