[refactor] Add KTransformers AMX MoE SFT support via Accelerate (#10430)

Co-authored-by: mrhaoxx <mr.haoxx@gmail.com>
Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
Peilin Li
2026-05-01 01:47:58 +08:00
committed by GitHub
parent 6b08b948c9
commit 887ee2b121
39 changed files with 287 additions and 1968 deletions

View File

@@ -16,6 +16,7 @@
# limitations under the License.
import json
import os
from dataclasses import asdict, dataclass, field, fields
from typing import Any, Literal, Self
@@ -460,47 +461,81 @@ class SGLangArguments:
@dataclass
class KTransformersArguments:
r"""Arguments pertaining to the KT training."""
r"""Arguments pertaining to KTransformers AMX MoE SFT training.
These fields are normalized into the transformers/accelerate KT config before training starts.
"""
use_kt: bool = field(
default=False,
metadata={"help": "Whether To Use KTransformers Optimizations For LoRA Training."},
metadata={"help": "Whether to use KTransformers AMX MoE backend for SFT training."},
)
kt_optimize_rule: str | None = field(
kt_weight_path: str | None = field(
default=None,
metadata={
"help": "Path To The KTransformers Optimize Rule; See https://github.com/kvcache-ai/ktransformers/."
},
metadata={"help": "Path to pre-quantized INT8 expert weights (.kt files)."},
)
cpu_infer: int | None = field(
default=32,
metadata={"help": "Number Of CPU Cores Used For Computation."},
kt_expert_checkpoint_path: str | None = field(
default=None,
metadata={"help": "Path to expert checkpoint (safetensors) for online conversion."},
)
chunk_size: int | None = field(
default=8192,
metadata={"help": "Chunk Size Used For CPU Compute In KTransformers."},
kt_use_lora_experts: bool | None = field(
default=None,
metadata={"help": "Whether to use GPU-side LoRA Experts."},
)
mode: str | None = field(
default="normal",
metadata={"help": "Normal Or Long_Context For Llama Models."},
kt_lora_expert_num: int | None = field(
default=None,
metadata={"help": "Number of GPU-side LoRA Experts."},
)
kt_lora_expert_intermediate_size: int | None = field(
default=None,
metadata={"help": "Intermediate size for GPU-side LoRA Experts."},
)
kt_maxlen: int = field(
default=4096,
metadata={"help": "Maximum Sequence (Prompt + Response) Length Of The KT Engine."},
)
kt_use_cuda_graph: bool = field(
default=True,
metadata={"help": "Whether To Use CUDA Graphs For The KT Engine."},
)
kt_mode: str = field(
default="normal",
metadata={"help": "Normal Or Long_Context Mode For The KT Engine."},
)
kt_force_think: bool = field(
default=False,
metadata={"help": "Force-Think Toggle For The KT Engine."},
)
def get_kt_config_dict(self, finetuning_args: Any, model_max_length: int | None) -> dict[str, Any]:
r"""Build KT config values from LLaMA-Factory model and LoRA arguments."""
kt_config = {
"kt_lora_rank": getattr(finetuning_args, "lora_rank", None),
"kt_lora_alpha": getattr(finetuning_args, "lora_alpha", None),
"kt_weight_path": self.kt_weight_path,
"kt_expert_checkpoint_path": self.kt_expert_checkpoint_path,
"kt_model_max_length": model_max_length,
"kt_use_lora_experts": self.kt_use_lora_experts,
"kt_lora_expert_num": self.kt_lora_expert_num,
"kt_lora_expert_intermediate_size": self.kt_lora_expert_intermediate_size,
}
return {key: value for key, value in kt_config.items() if value is not None}
def apply_kt_config(self, finetuning_args: Any, training_args: Any, model_max_length: int | None) -> None:
r"""Apply LLaMA-Factory KT args to transformers/accelerate KT integration points."""
if not self.use_kt:
return
kt_config = self.get_kt_config_dict(finetuning_args, model_max_length)
env_mapping = {
"kt_weight_path": "ACCELERATE_KT_WEIGHT_PATH",
"kt_expert_checkpoint_path": "ACCELERATE_KT_EXPERT_CHECKPOINT_PATH",
"kt_model_max_length": "ACCELERATE_KT_MODEL_MAX_LENGTH",
"kt_lora_rank": "ACCELERATE_KT_LORA_RANK",
"kt_lora_alpha": "ACCELERATE_KT_LORA_ALPHA",
"kt_use_lora_experts": "ACCELERATE_KT_USE_LORA_EXPERTS",
"kt_lora_expert_num": "ACCELERATE_KT_LORA_EXPERT_NUM",
"kt_lora_expert_intermediate_size": "ACCELERATE_KT_LORA_EXPERT_INTERMEDIATE_SIZE",
}
for key, env_key in env_mapping.items():
value = kt_config.get(key)
if value is not None:
os.environ[env_key] = str(value)
hf_kt = getattr(training_args, "hf_kt_config", None)
if hf_kt is None or not hasattr(hf_kt, "_kt_config") or not isinstance(hf_kt._kt_config, dict):
return
hf_kt._kt_config.update(kt_config)
gc_enabled = getattr(training_args, "gradient_checkpointing", False) or not getattr(
self, "disable_gradient_checkpointing", True
)
if gc_enabled:
hf_kt._kt_config.setdefault("kt_share_cache_pool", True)
@dataclass

View File

@@ -186,13 +186,16 @@ def _verify_model_args(
raise ValueError("Quantized model only accepts a single adapter. Merge them first.")
def _check_extra_dependencies(
model_args: "ModelArguments",
finetuning_args: "FinetuningArguments",
training_args: Optional["TrainingArguments"] = None,
) -> None:
if model_args.use_kt:
check_version("ktransformers", mandatory=True)
check_version("kt-kernel", mandatory=True)
check_version("transformers-kt", mandatory=True)
check_version("accelerate-kt", mandatory=True)
if model_args.use_unsloth:
check_version("unsloth", mandatory=True)
@@ -510,6 +513,9 @@ def get_train_args(args: dict[str, Any] | list[str] | None = None) -> _TRAIN_CLS
)
transformers.set_seed(training_args.seed)
if model_args.use_kt:
model_args.apply_kt_config(finetuning_args, training_args, model_args.model_max_length)
return model_args, data_args, training_args, finetuning_args, generating_args