[v1] Refactor kernel plugin (#9669)

Co-authored-by: frozenleaves <frozen@Mac.local>
This commit is contained in:
浮梦
2025-12-31 18:26:48 +08:00
committed by GitHub
parent 4e1d69579a
commit 16735b9e35
19 changed files with 777 additions and 433 deletions

View File

@@ -14,7 +14,7 @@
import torch
from llamafactory.v1.config.model_args import ModelArguments
from llamafactory.v1.config.model_args import ModelArguments, PluginConfig
from llamafactory.v1.core.model_loader import ModelLoader
@@ -29,5 +29,23 @@ def test_tiny_qwen():
assert model_loader.model.dtype == torch.bfloat16
def test_tiny_qwen_with_kernel_plugin():
    """Load the tiny Qwen2.5 model with the kernel plugin set to ``auto``.

    The first decoder layer's ``input_layernorm.forward`` must share its code
    object with ``npu_rms_norm_forward`` exactly when ``torch`` exposes an
    ``npu`` attribute; otherwise the two code objects must differ. In both
    cases the loaded model must be a ``Qwen2ForCausalLM`` instance.
    """
    from transformers import Qwen2ForCausalLM

    from llamafactory.v1.plugins.model_plugins.kernels.ops.rms_norm.npu_rms_norm import npu_rms_norm_forward

    plugin_config = PluginConfig(name="auto", include_kernels="auto")
    loader = ModelLoader(ModelArguments(model="llamafactory/tiny-random-qwen2.5", kernel_config=plugin_config))
    loaded_model = loader.model

    # Compare code objects rather than the callables themselves, so the check
    # does not depend on how the replacement forward is bound to the module.
    layernorm_code = loaded_model.model.layers[0].input_layernorm.forward.__code__
    uses_npu_forward = layernorm_code == npu_rms_norm_forward.__code__

    # The NPU forward is expected if and only if torch has an `npu` attribute.
    assert uses_npu_forward == hasattr(torch, "npu")

    assert isinstance(loaded_model, Qwen2ForCausalLM)
if __name__ == "__main__":
    # Allow running this file directly: execute each test in declaration order.
    for run_test in (test_tiny_qwen, test_tiny_qwen_with_kernel_plugin):
        run_test()