[refactor] Add KTransformers AMX MoE SFT support via Accelerate (#10430)

Co-authored-by: mrhaoxx <mr.haoxx@gmail.com>
Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
Peilin Li
2026-05-01 01:47:58 +08:00
committed by GitHub
parent 6b08b948c9
commit 887ee2b121
39 changed files with 287 additions and 1968 deletions

View File

@@ -21,7 +21,6 @@ from transformers.integrations import is_deepspeed_zero3_enabled
from ..extras import logging
from ..extras.constants import EngineName
from .model_utils.ktransformers import get_kt_peft_model, load_kt_peft_model
from .model_utils.misc import find_all_linear_modules, find_expanded_modules
from .model_utils.quantization import QuantizationMethod
from .model_utils.unsloth import get_unsloth_peft_model, load_unsloth_peft_model
@@ -188,12 +187,6 @@ def _setup_lora_tuning(
"token": model_args.hf_hub_token,
}
if model_args.use_kt:
if model_args.infer_backend != EngineName.KT:
raise ValueError(
"We should use ktransformers as backend to infer the adapter fine-tuned by ktransformers."
)
for adapter in adapter_to_merge:
model: LoraModel = PeftModel.from_pretrained(model, adapter, **init_kwargs)
model = model.merge_and_unload()
@@ -202,9 +195,7 @@ def _setup_lora_tuning(
logger.info_rank0(f"Merged {len(adapter_to_merge)} adapter(s).")
if adapter_to_resume is not None: # resume lora training
if model_args.use_kt:
model = load_kt_peft_model(model_args, model)
elif model_args.use_unsloth:
if model_args.use_unsloth:
model = load_unsloth_peft_model(config, model_args, finetuning_args, is_trainable=is_trainable)
else:
model = PeftModel.from_pretrained(model, adapter_to_resume, is_trainable=is_trainable, **init_kwargs)
@@ -217,16 +208,6 @@ def _setup_lora_tuning(
else:
target_modules = finetuning_args.lora_target
if model_args.use_kt:
new_list = []
for m in target_modules:
if m in ("down_proj", "up_proj", "gate_proj"):
new_list.extend([f"mlp.{m}", f"shared_experts.{m}"])
elif m not in ("generate_linear", "orig_module", "prefill_linear"):
new_list.append(m)
target_modules[:] = new_list
if finetuning_args.use_llama_pro:
target_modules = find_expanded_modules(model, target_modules, finetuning_args.freeze_trainable_layers)
@@ -270,19 +251,11 @@ def _setup_lora_tuning(
}
if model_args.use_kt:
if finetuning_args.finetuning_type == "oft":
raise ValueError("KTransformers is currently not supported for OFT.")
if finetuning_args.finetuning_type == "lora":
peft_config = LoraConfig(
task_type=TaskType.CAUSAL_LM,
inference_mode=False,
**peft_kwargs,
)
else:
raise ValueError("KTransformers is currently only supported for LoRA.")
if finetuning_args.finetuning_type != "lora":
raise ValueError("KTransformers only supports LoRA finetuning.")
model = get_kt_peft_model(model, peft_config)
print(f"KT_model:{model}")
peft_config = LoraConfig(task_type=TaskType.CAUSAL_LM, inference_mode=False, **peft_kwargs)
model = get_peft_model(model, peft_config)
elif model_args.use_unsloth:
if finetuning_args.finetuning_type == "oft":
raise ValueError("Unsloth is currently not supported for OFT.")