[feat] support ktransformers for dpo (#9621)

Co-authored-by: poryfly <porykid@gmail.com>
This commit is contained in:
mrhaoxx
2025-12-18 21:26:25 +08:00
committed by GitHub
parent 964569751f
commit a769fb94b9
2 changed files with 81 additions and 1 deletion

View File

@@ -24,7 +24,6 @@ from ...extras.ploting import plot_loss
from ...hparams import ModelArguments
from ...model import load_model, load_tokenizer
from ..trainer_utils import create_modelcard_and_push, create_ref_model
from .trainer import CustomDPOTrainer
if TYPE_CHECKING:
@@ -62,6 +61,16 @@ def run_dpo(
ref_model = create_ref_model(model_args, finetuning_args)
else:
ref_model = None
if model_args.use_kt:
from ktransformers.util.globals import GLOBAL_CONFIG
GLOBAL_CONFIG._config["mod"] = "sft"
from .ktrainer import CustomDPOTrainer
else:
from .trainer import CustomDPOTrainer
# Initialize our Trainer
trainer = CustomDPOTrainer(