mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2026-01-10 08:00:36 +08:00
[v1] Refactor kernel plugin (#9669)
Co-authored-by: frozenleaves <frozen@Mac.local>
This commit is contained in:
@@ -14,7 +14,7 @@
|
||||
|
||||
import torch
|
||||
|
||||
from llamafactory.v1.config.model_args import ModelArguments
|
||||
from llamafactory.v1.config.model_args import ModelArguments, PluginConfig
|
||||
from llamafactory.v1.core.model_loader import ModelLoader
|
||||
|
||||
|
||||
@@ -29,5 +29,23 @@ def test_tiny_qwen():
|
||||
assert model_loader.model.dtype == torch.bfloat16
|
||||
|
||||
|
||||
def test_tiny_qwen_with_kernel_plugin():
|
||||
from transformers import Qwen2ForCausalLM
|
||||
|
||||
from llamafactory.v1.plugins.model_plugins.kernels.ops.rms_norm.npu_rms_norm import npu_rms_norm_forward
|
||||
|
||||
model_args = ModelArguments(
|
||||
model="llamafactory/tiny-random-qwen2.5", kernel_config=PluginConfig(name="auto", include_kernels="auto")
|
||||
)
|
||||
model_loader = ModelLoader(model_args)
|
||||
# test enable apply kernel plugin
|
||||
if hasattr(torch, "npu"):
|
||||
assert model_loader.model.model.layers[0].input_layernorm.forward.__code__ == npu_rms_norm_forward.__code__
|
||||
else:
|
||||
assert model_loader.model.model.layers[0].input_layernorm.forward.__code__ != npu_rms_norm_forward.__code__
|
||||
assert isinstance(model_loader.model, Qwen2ForCausalLM)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_tiny_qwen()
|
||||
test_tiny_qwen_with_kernel_plugin()
|
||||
|
||||
Reference in New Issue
Block a user