[train] KTransformers SFT as backend engine for LLaMA-Factory (#9400)

Co-authored-by: jimmy128 <jimmy128@noreply.gitcode.com>
Co-authored-by: Yaowei Zheng <hiyouga@buaa.edu.cn>

Author: Peilin Li
Committed by: GitHub
Date: 2025-11-04 15:54:12 +08:00
Parent: 3ae15da9c0
Commit: 934b3084ee
37 changed files with 2006 additions and 16 deletions

View File

@@ -20,6 +20,8 @@ from peft import LoraConfig, LoraModel, OFTConfig, PeftModel, TaskType, get_peft_model
from transformers.integrations import is_deepspeed_zero3_enabled
from ..extras import logging
from ..extras.constants import EngineName
from .model_utils.ktransformers import get_kt_peft_model, load_kt_peft_model
from .model_utils.misc import find_all_linear_modules, find_expanded_modules
from .model_utils.quantization import QuantizationMethod
from .model_utils.unsloth import get_unsloth_peft_model, load_unsloth_peft_model
@@ -164,6 +166,10 @@ def _setup_lora_tuning(
            assert len(model_args.adapter_name_or_path) == 1, "Cannot use multiple adapters in DeepSpeed ZeRO-3."
            is_mergeable = False

        if model_args.use_kt:
            assert len(model_args.adapter_name_or_path) == 1, "KTransformers currently accepts only a single adapter; contact us if you need more."
            is_mergeable = False

        if model_args.use_unsloth:
            assert len(model_args.adapter_name_or_path) == 1, "Unsloth model only accepts a single adapter."
            is_mergeable = False
@@ -182,6 +188,10 @@ def _setup_lora_tuning(
"token": model_args.hf_hub_token,
}
if model_args.use_kt:
if model_args.infer_backend != EngineName.KT:
raise ValueError("We should use ktransformers as backend to infer the adapter fine-tuned by ktransformers.")
for adapter in adapter_to_merge:
model: LoraModel = PeftModel.from_pretrained(model, adapter, **init_kwargs)
model = model.merge_and_unload()
@@ -190,7 +200,9 @@ def _setup_lora_tuning(
logger.info_rank0(f"Merged {len(adapter_to_merge)} adapter(s).")
if adapter_to_resume is not None: # resume lora training
if model_args.use_unsloth:
if model_args.use_kt:
model = load_kt_peft_model(model_args, model)
elif model_args.use_unsloth:
model = load_unsloth_peft_model(config, model_args, finetuning_args, is_trainable=is_trainable)
else:
model = PeftModel.from_pretrained(model, adapter_to_resume, is_trainable=is_trainable, **init_kwargs)
@@ -203,6 +215,16 @@ def _setup_lora_tuning(
        else:
            target_modules = finetuning_args.lora_target

        if model_args.use_kt:
            new_list = []
            for m in target_modules:
                if m in ("down_proj", "up_proj", "gate_proj"):
                    new_list.extend([f"mlp.{m}", f"shared_experts.{m}"])
                elif m not in ("generate_linear", "orig_module", "prefill_linear"):
                    new_list.append(m)

            target_modules[:] = new_list

        if finetuning_args.use_llama_pro:
            target_modules = find_expanded_modules(model, target_modules, finetuning_args.freeze_trainable_layers)
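
The remapping above rewrites the LoRA target names for KTransformers' module layout: MLP projections are duplicated onto both the ordinary MLP path and the shared experts, and KT's internal wrapper linears are excluded from LoRA. A minimal sketch of its effect (module names assume a DeepSeek-style MoE block and are not part of the commit):

# Illustrative only: effect of the KT target-module remapping above.
target_modules = ["q_proj", "down_proj", "generate_linear"]
new_list = []
for m in target_modules:
    if m in ("down_proj", "up_proj", "gate_proj"):
        new_list.extend([f"mlp.{m}", f"shared_experts.{m}"])  # cover dense MLP and shared experts
    elif m not in ("generate_linear", "orig_module", "prefill_linear"):
        new_list.append(m)  # keep ordinary modules, drop KT wrapper linears
target_modules[:] = new_list
print(target_modules)  # ['q_proj', 'mlp.down_proj', 'shared_experts.down_proj']
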
@@ -245,7 +267,21 @@ def _setup_lora_tuning(
"modules_to_save": finetuning_args.additional_target,
}
if model_args.use_unsloth:
if model_args.use_kt:
if finetuning_args.finetuning_type == "oft":
raise ValueError("KTransformers is currently not supported for OFT.")
if finetuning_args.finetuning_type == "lora":
peft_config = LoraConfig(
task_type=TaskType.CAUSAL_LM,
inference_mode=False,
**peft_kwargs,
)
else:
raise ValueError("KTransformers is currently only supported for LoRA.")
model = get_kt_peft_model(model, peft_config)
print(f"KT_model:{model}")
elif model_args.use_unsloth:
if finetuning_args.finetuning_type == "oft":
raise ValueError("Unsloth is currently not supported for OFT.")

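In the KT LoRA branch above, the peft_kwargs assembled for the stock LoRA path are forwarded unchanged; only the wrapping call differs, deferring to ktransformers.sft.peft_utils.mapping.get_peft_model instead of peft.get_peft_model. A rough sketch with assumed hyperparameters (rank 8, alpha 16) and remapped target modules:

# Illustrative only: the config fed to get_kt_peft_model; hyperparameter values are placeholders.
from peft import LoraConfig, TaskType

peft_kwargs = {
    "r": 8,  # finetuning_args.lora_rank
    "target_modules": ["q_proj", "mlp.down_proj", "shared_experts.down_proj"],  # after KT remapping
    "lora_alpha": 16,  # finetuning_args.lora_alpha
    "lora_dropout": 0.0,
    "use_rslora": False,
    "use_dora": False,
    "modules_to_save": None,
}
peft_config = LoraConfig(task_type=TaskType.CAUSAL_LM, inference_mode=False, **peft_kwargs)
# model = get_kt_peft_model(model, peft_config)
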
View File

@@ -31,6 +31,7 @@ from trl import AutoModelForCausalLMWithValueHead
from ..extras import logging
from ..extras.misc import count_parameters, skip_check_imports, try_download_model_from_other_hub
from .adapter import init_adapter
from .model_utils.ktransformers import load_kt_pretrained_model
from .model_utils.liger_kernel import apply_liger_kernel
from .model_utils.misc import register_autoclass
from .model_utils.mod import convert_pretrained_model_to_mod, load_mod_pretrained_model
@@ -143,7 +144,11 @@ def load_model(
    model = None
    lazy_load = False
    if model_args.use_unsloth:
    if model_args.use_kt:
        from ktransformers.sft.monkey_patch_torch_module import install_patch

        install_patch()
        model = load_kt_pretrained_model(config, model_args)
    elif model_args.use_unsloth:
        if model_args.adapter_name_or_path is not None:
            lazy_load = True
        elif is_trainable:

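The loader change is gated entirely by ModelArguments: when use_kt is set, a torch-module monkey patch is installed before the model is built, and loading is delegated to load_kt_pretrained_model from the new module below. A sketch of the KT-specific fields that path reads (placeholder values; the string behind EngineName.KT is assumed):

# Illustrative only: ModelArguments fields consulted by the KTransformers loading path.
kt_settings = {
    "use_kt": True,  # route load_model() through load_kt_pretrained_model
    "infer_backend": "ktransformers",  # EngineName.KT; required when serving a KT-finetuned adapter
    "kt_optimize_rule": "DeepSeek-V3-Chat-sft.yaml",  # hypothetical YAML placement-rule file
    "cpu_infer": 32,  # CPU threads handed to ktransformers' Config()
    "chunk_size": 8192,  # prefill chunk size handed to Config()
}
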
View File

@@ -0,0 +1,159 @@
# Copyright 2025 the KVCache.AI team, Approaching AI, and the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib.util as _u
from typing import TYPE_CHECKING, Any, Optional

import torch

from ...extras import logging
from ...extras.misc import get_current_device


if TYPE_CHECKING:
    from ...hparams import FinetuningArguments, ModelArguments

from transformers import AutoConfig, AutoModelForCausalLM, PretrainedConfig, PreTrainedModel

KT_AVAILABLE = _u.find_spec("ktransformers") is not None

if KT_AVAILABLE:
    from ktransformers.models.modeling_deepseek import DeepseekV2ForCausalLM
    from ktransformers.models.modeling_deepseek_v3 import DeepseekV3ForCausalLM
    from ktransformers.models.modeling_llama import LlamaForCausalLM
    from ktransformers.models.modeling_mixtral import MixtralForCausalLM
    from ktransformers.models.modeling_qwen2_moe import Qwen2MoeForCausalLM
    from ktransformers.optimize.optimize import optimize_and_load_gguf
    from ktransformers.server.config.config import Config
    from ktransformers.sft.lora import inject_lora_layer
    from ktransformers.util.custom_loader import GGUFLoader, SafeTensorLoader
    from ktransformers.util.globals import GLOBAL_CONFIG
    from ktransformers.util.utils import load_weights


logger = logging.get_logger(__name__)


def _get_kt_kwargs(
    config: "PretrainedConfig",
    model_name_or_path: str,
    model_args: "ModelArguments",
    finetuning_args: "FinetuningArguments",
) -> dict[str, Any]:
    return {
        "model_name": model_name_or_path,
        "max_seq_length": model_args.model_max_length or 4096,
        "dtype": model_args.compute_dtype,
        "load_in_4bit": model_args.quantization_bit == 4,
        "token": model_args.hf_hub_token,
        "full_finetuning": finetuning_args.finetuning_type == "full",
        "device_map": {"": get_current_device()},
        "rope_scaling": getattr(config, "rope_scaling", None),
        "fix_tokenizer": False,
        "trust_remote_code": model_args.trust_remote_code,
        "use_gradient_checkpointing": "ktransformers",
    }


def load_kt_pretrained_model(
    config: "PretrainedConfig", model_args: "ModelArguments"
) -> Optional["PreTrainedModel"]:
    r"""Optionally load pretrained model with KTransformers. Used in training."""
    custom_models = {
        "DeepseekV2ForCausalLM": DeepseekV2ForCausalLM,
        "DeepseekV3ForCausalLM": DeepseekV3ForCausalLM,
        "Qwen2MoeForCausalLM": Qwen2MoeForCausalLM,
        "LlamaForCausalLM": LlamaForCausalLM,
        "MixtralForCausalLM": MixtralForCausalLM,
    }
    Config().cpu_infer = model_args.cpu_infer
    Config().chunk_size = model_args.chunk_size
    config = AutoConfig.from_pretrained(model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code)
    if model_args.mode == 'long_context':
        assert config.architectures[0] == "LlamaForCausalLM", "only LlamaForCausalLM supports long_context mode"
        torch.set_default_dtype(torch.float16)
    else:
        torch.set_default_dtype(config.torch_dtype)

    with torch.device("meta"):
        if config.architectures[0] in custom_models:
            print("using custom modeling_xxx.py.")
            if (
                "Qwen2Moe" in config.architectures[0]
            ):  # Qwen2Moe must use flash_attention_2 to avoid overflow.
                config._attn_implementation = "flash_attention_2"
            if "Llama" in config.architectures[0]:
                config._attn_implementation = "eager"
            if "Mixtral" in config.architectures[0]:
                config._attn_implementation = "flash_attention_2"

            model = custom_models[config.architectures[0]](config)
        else:
            attn_implementation = "flash_attention_2"
            model = AutoModelForCausalLM.from_config(
                config, trust_remote_code=True, attn_implementation=attn_implementation
            )

    optimize_config_path = model_args.kt_optimize_rule
    gguf_path = model_args.model_name_or_path
    assert optimize_config_path is not None, "optimize_config_path must be provided (path to YAML rules file)."
    assert gguf_path is not None, "gguf_path must be provided (path to a folder or .gguf file)."
    GLOBAL_CONFIG._config["mod"] = "infer"
    optimize_and_load_gguf(model, optimize_config_path, gguf_path, config)
    return model


def get_kt_peft_model(
    model: "PreTrainedModel", peft_kwargs: dict[str, Any]
) -> "PreTrainedModel":
    r"""Get the peft model for the pretrained model with KTransformers. Used in training."""
    from ktransformers.sft.peft_utils.mapping import get_peft_model

    return get_peft_model(model, peft_kwargs)


def load_kt_peft_model(
    model_args: "ModelArguments", model: "PreTrainedModel",
) -> "PreTrainedModel":
    r"""Load peft model with KTransformers. Used in both training and inference."""
    load_adapter_name_or_path = model_args.adapter_name_or_path[0]
    if load_adapter_name_or_path.endswith('.gguf'):
        inject_lora_layer(model, load_adapter_name_or_path)
        adapter_gguf_loader = GGUFLoader(load_adapter_name_or_path)
        load_weights(model, adapter_gguf_loader, adapter_gguf=True)
        model.train()
    else:
        inject_lora_layer(model, load_adapter_name_or_path)
        adapter_loader = SafeTensorLoader(load_adapter_name_or_path)
        device = next(model.parameters()).device
        for key in adapter_loader.tensor_file_map.keys():
            try:
                tensor = adapter_loader.load_tensor(key, device=device)
                model_key = key.replace("base_model.model.", "")
                model_key = model_key.replace(".weight", ".default.weight")
                model_key = model_key.replace(".default.default.weight", ".default.weight")
                param = model.get_parameter(model_key)
                param.data.copy_(tensor.data)
                print(f"Loaded adapter weight: {key} -> {model_key}")
            except AttributeError:
                print(f"Skipping {key}: not a model parameter")
            except KeyError:
                print(f"Key not found in model: {model_key} (original: {key})")

    return model