[train] KTransformers SFT as a backend engine for LLaMA-Factory (#9400)

Co-authored-by: jimmy128 <jimmy128@noreply.gitcode.com>
Co-authored-by: Yaowei Zheng <hiyouga@buaa.edu.cn>
Author: Peilin Li
Date: 2025-11-04 15:54:12 +08:00
Committed by: GitHub
Parent: 3ae15da9c0
Commit: 934b3084ee
37 changed files with 2006 additions and 16 deletions


@@ -156,6 +156,9 @@ def _check_extra_dependencies(
     finetuning_args: "FinetuningArguments",
     training_args: Optional["TrainingArguments"] = None,
 ) -> None:
+    if model_args.use_kt:
+        check_version("ktransformers", mandatory=True)
+
     if model_args.use_unsloth:
         check_version("unsloth", mandatory=True)
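
The new `use_kt` flag makes the ktransformers package a hard dependency whenever the backend is enabled. Below is a minimal, self-contained sketch of what this gate amounts to, using `importlib.metadata` instead of LLaMA-Factory's own `check_version` helper; `require_package` is a hypothetical stand-in, not part of the repository.

# Minimal sketch: fail fast if a backend is requested but its package is missing.
from importlib.metadata import PackageNotFoundError, version


def require_package(name: str) -> str:
    """Return the installed version of `name`, or raise if it is not installed."""
    try:
        return version(name)
    except PackageNotFoundError as exc:
        raise ImportError(f"{name} is required but not installed.") from exc


use_kt = True  # stands in for model_args.use_kt
if use_kt:
    # Raises ImportError when ktransformers is absent, mirroring the mandatory check above.
    print("ktransformers", require_package("ktransformers"))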
@@ -282,13 +285,16 @@ def get_train_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _
         if model_args.shift_attn:
             raise ValueError("PPO training is incompatible with S^2-Attn.")
 
+        if finetuning_args.reward_model_type == "lora" and model_args.use_kt:
+            raise ValueError("KTransformers does not support lora reward model.")
+
         if finetuning_args.reward_model_type == "lora" and model_args.use_unsloth:
             raise ValueError("Unsloth does not support lora reward model.")
 
         if training_args.report_to and training_args.report_to[0] not in ["wandb", "tensorboard"]:
             raise ValueError("PPO only accepts wandb or tensorboard logger.")
 
-    if training_args.parallel_mode == ParallelMode.NOT_DISTRIBUTED:
+    if not model_args.use_kt and training_args.parallel_mode == ParallelMode.NOT_DISTRIBUTED:
         raise ValueError("Please launch distributed training with `llamafactory-cli` or `torchrun`.")
 
     if training_args.deepspeed and training_args.parallel_mode != ParallelMode.DISTRIBUTED:
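
The relaxed launch check above lets a KTransformers run start in a single, non-distributed process, while every other backend must still be launched through `llamafactory-cli` or `torchrun`. The snippet below is a standalone sketch of that guard, assuming only the two fields it reads; `ParallelMode` and the argument containers are simplified stand-ins for the transformers types, not the real ones.

# Standalone sketch of the relaxed launch guard.
from enum import Enum
from types import SimpleNamespace


class ParallelMode(Enum):
    NOT_DISTRIBUTED = "not_distributed"
    DISTRIBUTED = "distributed"


def check_launch(model_args, training_args) -> None:
    # A non-distributed launch is only rejected when the KTransformers backend is NOT enabled.
    if not model_args.use_kt and training_args.parallel_mode == ParallelMode.NOT_DISTRIBUTED:
        raise ValueError("Please launch distributed training with `llamafactory-cli` or `torchrun`.")


# Single-process run with the KTransformers backend: passes.
check_launch(SimpleNamespace(use_kt=True), SimpleNamespace(parallel_mode=ParallelMode.NOT_DISTRIBUTED))

# Single-process run without it would raise ValueError:
# check_launch(SimpleNamespace(use_kt=False), SimpleNamespace(parallel_mode=ParallelMode.NOT_DISTRIBUTED))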
@@ -350,6 +356,9 @@ def get_train_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _
     if model_args.use_unsloth and is_deepspeed_zero3_enabled():
         raise ValueError("Unsloth is incompatible with DeepSpeed ZeRO-3.")
 
+    if model_args.use_kt and is_deepspeed_zero3_enabled():
+        raise ValueError("KTransformers is incompatible with DeepSpeed ZeRO-3.")
+
     if data_args.neat_packing and is_transformers_version_greater_than("4.53.0"):
         raise ValueError("Neat packing is incompatible with transformers>=4.53.0.")
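
As the hunk headers show, `get_train_args()` accepts either CLI-style arguments or a plain dict, so a KTransformers SFT run can also be configured programmatically. The dict below is a hedged sketch: `use_kt` comes from this commit, the other keys follow standard LLaMA-Factory argument names, and every concrete value (model path, dataset, output directory) is illustrative rather than taken from the commit.

# Hedged sketch of a dict-style argument set for a KTransformers SFT run.
kt_sft_args = {
    "stage": "sft",
    "do_train": True,
    "model_name_or_path": "path/to/base-model",  # illustrative
    "dataset": "identity",                       # illustrative dataset name
    "template": "default",                       # illustrative chat template
    "finetuning_type": "lora",
    "use_kt": True,                              # enable the KTransformers backend (this commit)
    "output_dir": "saves/kt-sft",                # illustrative
    "per_device_train_batch_size": 1,
    "learning_rate": 1.0e-4,
    "num_train_epochs": 1.0,
}

# With a working install, something like the following would parse and validate
# these arguments (module path and return tuple are assumptions, not from the diff):
# from llamafactory.hparams import get_train_args
# model_args, data_args, training_args, finetuning_args, generating_args = get_train_args(kt_sft_args)

Note that, per the checks above, such a run must not combine `use_kt` with DeepSpeed ZeRO-3 or with a LoRA reward model in PPO, but it may run in a single non-distributed process.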