[train] KTransformers SFT as backend engine for LLaMA-Factory (#9400)

Co-authored-by: jimmy128 <jimmy128@noreply.gitcode.com>
Co-authored-by: Yaowei Zheng <hiyouga@buaa.edu.cn>
This commit is contained in:
Peilin Li
2025-11-04 15:54:12 +08:00
committed by GitHub
parent 3ae15da9c0
commit 934b3084ee
37 changed files with 2006 additions and 16 deletions

View File

@@ -1,4 +1,4 @@
# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
# Copyright 2025 HuggingFace Inc., the KVCache.AI team, Approaching AI, and the LlamaFactory team.
#
# This code is inspired by the HuggingFace's transformers library.
# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/language-modeling/run_clm.py
@@ -475,9 +475,51 @@ class SGLangArguments:
self.sglang_config = _convert_str_dict(json.loads(self.sglang_config))
@dataclass
class KTransformersArguments:
r"""Arguments pertaining to the KT training."""
use_kt: bool = field(
default=False,
metadata={"help": "Whether To Use KTransformers Optimizations For LoRA Training."},
)
kt_optimize_rule: Optional[str] = field(
default=None,
metadata={"help": "Path To The KTransformers Optimize Rule; See https://github.com/kvcache-ai/ktransformers/."},
)
cpu_infer: Optional[int] = field(
default=32,
metadata={"help": "Number Of CPU Cores Used For Computation."},
)
chunk_size: Optional[int] = field(
default=8192,
metadata={"help": "Chunk Size Used For CPU Compute In KTransformers."},
)
mode: Optional[str] = field(
default="normal",
metadata={"help": "Normal Or Long_Context For Llama Models."},
)
kt_maxlen: int = field(
default=4096,
metadata={"help": "Maximum Sequence (Prompt + Response) Length Of The KT Engine."},
)
kt_use_cuda_graph: bool = field(
default=True,
metadata={"help": "Whether To Use CUDA Graphs For The KT Engine."},
)
kt_mode: str = field(
default="normal",
metadata={"help": "Normal Or Long_Context Mode For The KT Engine."},
)
kt_force_think: bool = field(
default=False,
metadata={"help": "Force-Think Toggle For The KT Engine."},
)
@dataclass
class ModelArguments(
SGLangArguments, VllmArguments, ExportArguments, ProcessorArguments, QuantizationArguments, BaseModelArguments
SGLangArguments, VllmArguments, KTransformersArguments, ExportArguments, ProcessorArguments, QuantizationArguments, BaseModelArguments
):
r"""Arguments pertaining to which model/config/tokenizer we are going to fine-tune or infer.