mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2026-05-28 10:58:54 +08:00
[refactor] Add KTransformers AMX MoE SFT support via Accelerate (#10430)
Co-authored-by: mrhaoxx <mr.haoxx@gmail.com> Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -16,6 +16,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import json
|
||||
import os
|
||||
from dataclasses import asdict, dataclass, field, fields
|
||||
from typing import Any, Literal, Self
|
||||
|
||||
@@ -460,47 +461,81 @@ class SGLangArguments:
|
||||
|
||||
@dataclass
|
||||
class KTransformersArguments:
|
||||
r"""Arguments pertaining to the KT training."""
|
||||
r"""Arguments pertaining to KTransformers AMX MoE SFT training.
|
||||
|
||||
These fields are normalized into the transformers/accelerate KT config before training starts.
|
||||
"""
|
||||
|
||||
use_kt: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Whether To Use KTransformers Optimizations For LoRA Training."},
|
||||
metadata={"help": "Whether to use KTransformers AMX MoE backend for SFT training."},
|
||||
)
|
||||
kt_optimize_rule: str | None = field(
|
||||
kt_weight_path: str | None = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "Path To The KTransformers Optimize Rule; See https://github.com/kvcache-ai/ktransformers/."
|
||||
},
|
||||
metadata={"help": "Path to pre-quantized INT8 expert weights (.kt files)."},
|
||||
)
|
||||
cpu_infer: int | None = field(
|
||||
default=32,
|
||||
metadata={"help": "Number Of CPU Cores Used For Computation."},
|
||||
kt_expert_checkpoint_path: str | None = field(
|
||||
default=None,
|
||||
metadata={"help": "Path to expert checkpoint (safetensors) for online conversion."},
|
||||
)
|
||||
chunk_size: int | None = field(
|
||||
default=8192,
|
||||
metadata={"help": "Chunk Size Used For CPU Compute In KTransformers."},
|
||||
kt_use_lora_experts: bool | None = field(
|
||||
default=None,
|
||||
metadata={"help": "Whether to use GPU-side LoRA Experts."},
|
||||
)
|
||||
mode: str | None = field(
|
||||
default="normal",
|
||||
metadata={"help": "Normal Or Long_Context For Llama Models."},
|
||||
kt_lora_expert_num: int | None = field(
|
||||
default=None,
|
||||
metadata={"help": "Number of GPU-side LoRA Experts."},
|
||||
)
|
||||
kt_lora_expert_intermediate_size: int | None = field(
|
||||
default=None,
|
||||
metadata={"help": "Intermediate size for GPU-side LoRA Experts."},
|
||||
)
|
||||
|
||||
kt_maxlen: int = field(
|
||||
default=4096,
|
||||
metadata={"help": "Maximum Sequence (Prompt + Response) Length Of The KT Engine."},
|
||||
)
|
||||
kt_use_cuda_graph: bool = field(
|
||||
default=True,
|
||||
metadata={"help": "Whether To Use CUDA Graphs For The KT Engine."},
|
||||
)
|
||||
kt_mode: str = field(
|
||||
default="normal",
|
||||
metadata={"help": "Normal Or Long_Context Mode For The KT Engine."},
|
||||
)
|
||||
kt_force_think: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Force-Think Toggle For The KT Engine."},
|
||||
)
|
||||
def get_kt_config_dict(self, finetuning_args: Any, model_max_length: int | None) -> dict[str, Any]:
|
||||
r"""Build KT config values from LLaMA-Factory model and LoRA arguments."""
|
||||
kt_config = {
|
||||
"kt_lora_rank": getattr(finetuning_args, "lora_rank", None),
|
||||
"kt_lora_alpha": getattr(finetuning_args, "lora_alpha", None),
|
||||
"kt_weight_path": self.kt_weight_path,
|
||||
"kt_expert_checkpoint_path": self.kt_expert_checkpoint_path,
|
||||
"kt_model_max_length": model_max_length,
|
||||
"kt_use_lora_experts": self.kt_use_lora_experts,
|
||||
"kt_lora_expert_num": self.kt_lora_expert_num,
|
||||
"kt_lora_expert_intermediate_size": self.kt_lora_expert_intermediate_size,
|
||||
}
|
||||
return {key: value for key, value in kt_config.items() if value is not None}
|
||||
|
||||
def apply_kt_config(self, finetuning_args: Any, training_args: Any, model_max_length: int | None) -> None:
|
||||
r"""Apply LLaMA-Factory KT args to transformers/accelerate KT integration points."""
|
||||
if not self.use_kt:
|
||||
return
|
||||
|
||||
kt_config = self.get_kt_config_dict(finetuning_args, model_max_length)
|
||||
env_mapping = {
|
||||
"kt_weight_path": "ACCELERATE_KT_WEIGHT_PATH",
|
||||
"kt_expert_checkpoint_path": "ACCELERATE_KT_EXPERT_CHECKPOINT_PATH",
|
||||
"kt_model_max_length": "ACCELERATE_KT_MODEL_MAX_LENGTH",
|
||||
"kt_lora_rank": "ACCELERATE_KT_LORA_RANK",
|
||||
"kt_lora_alpha": "ACCELERATE_KT_LORA_ALPHA",
|
||||
"kt_use_lora_experts": "ACCELERATE_KT_USE_LORA_EXPERTS",
|
||||
"kt_lora_expert_num": "ACCELERATE_KT_LORA_EXPERT_NUM",
|
||||
"kt_lora_expert_intermediate_size": "ACCELERATE_KT_LORA_EXPERT_INTERMEDIATE_SIZE",
|
||||
}
|
||||
for key, env_key in env_mapping.items():
|
||||
value = kt_config.get(key)
|
||||
if value is not None:
|
||||
os.environ[env_key] = str(value)
|
||||
|
||||
hf_kt = getattr(training_args, "hf_kt_config", None)
|
||||
if hf_kt is None or not hasattr(hf_kt, "_kt_config") or not isinstance(hf_kt._kt_config, dict):
|
||||
return
|
||||
|
||||
hf_kt._kt_config.update(kt_config)
|
||||
gc_enabled = getattr(training_args, "gradient_checkpointing", False) or not getattr(
|
||||
self, "disable_gradient_checkpointing", True
|
||||
)
|
||||
if gc_enabled:
|
||||
hf_kt._kt_config.setdefault("kt_share_cache_pool", True)
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -186,13 +186,16 @@ def _verify_model_args(
|
||||
raise ValueError("Quantized model only accepts a single adapter. Merge them first.")
|
||||
|
||||
|
||||
|
||||
def _check_extra_dependencies(
|
||||
model_args: "ModelArguments",
|
||||
finetuning_args: "FinetuningArguments",
|
||||
training_args: Optional["TrainingArguments"] = None,
|
||||
) -> None:
|
||||
if model_args.use_kt:
|
||||
check_version("ktransformers", mandatory=True)
|
||||
check_version("kt-kernel", mandatory=True)
|
||||
check_version("transformers-kt", mandatory=True)
|
||||
check_version("accelerate-kt", mandatory=True)
|
||||
|
||||
if model_args.use_unsloth:
|
||||
check_version("unsloth", mandatory=True)
|
||||
@@ -510,6 +513,9 @@ def get_train_args(args: dict[str, Any] | list[str] | None = None) -> _TRAIN_CLS
|
||||
)
|
||||
transformers.set_seed(training_args.seed)
|
||||
|
||||
if model_args.use_kt:
|
||||
model_args.apply_kt_config(finetuning_args, training_args, model_args.model_max_length)
|
||||
|
||||
return model_args, data_args, training_args, finetuning_args, generating_args
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user