[model] update kt code (#9406)

This commit is contained in:
Yaowei Zheng
2025-11-05 15:27:22 +08:00
committed by GitHub
parent 56f45e826f
commit eaf963f67f
28 changed files with 108 additions and 68 deletions

View File

@@ -24,7 +24,7 @@ from ..data import get_template_and_fix_tokenizer
from ..extras import logging
from ..extras.constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
from ..extras.misc import infer_optim_dtype
from ..extras.packages import is_kt_available, is_mcore_adapter_available, is_ray_available
from ..extras.packages import is_mcore_adapter_available, is_ray_available
from ..hparams import get_infer_args, get_ray_args, get_train_args, read_args
from ..model import load_model, load_tokenizer
from .callbacks import LogCallback, PissaConvertCallback, ReporterCallback
@@ -86,12 +86,12 @@ def _training_function(config: dict[str, Any]) -> None:
run_pt(model_args, data_args, training_args, finetuning_args, callbacks)
elif finetuning_args.stage == "sft":
if model_args.use_kt:
if not is_kt_available():
raise ImportError("KTransformers is not installed. Please install it with `pip install ktransformers`.")
from .ksft.workflow import run_sft as run_sft_kt
run_sft_kt(model_args, data_args, training_args, finetuning_args, generating_args, callbacks)
run_sft(model_args, data_args, training_args, finetuning_args, generating_args, callbacks)
else:
run_sft(model_args, data_args, training_args, finetuning_args, generating_args, callbacks)
elif finetuning_args.stage == "rm":
run_rm(model_args, data_args, training_args, finetuning_args, callbacks)
elif finetuning_args.stage == "ppo":