mirror of https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-12-29 02:00:36 +08:00

Compare commits: 591fc9ed02 ... 231756a5bf (4 commits)

| Author | SHA1 | Date |
|---|---|---|
| | 231756a5bf | |
| | 2c4fb3c97e | |
| | 2b6f16f261 | |
| | f17efde693 | |
examples/ascend/qwen3vlmoe_lora_sft_fsdp.yaml (new file, 42 lines)
@@ -0,0 +1,42 @@
+### model
+model_name_or_path: Qwen/Qwen3-VL-30B-A3B-Instruct
+image_max_pixels: 262144
+video_max_pixels: 16384
+trust_remote_code: true
+use_kernels: true  # replaced kernels: [NpuRMSNormKernel, NpuRoPEKernel, NpuQwen3VLMoEFusedMoEKernel]
+
+### method
+stage: sft
+do_train: true
+finetuning_type: lora
+lora_rank: 8
+lora_target: all
+disable_gradient_checkpointing: false
+flash_attn: disabled
+
+### dataset
+dataset: alpaca_zh_demo, alpaca_en_demo
+template: qwen3_vl
+cutoff_len: 1024
+overwrite_cache: true
+preprocessing_num_workers: 16
+dataloader_num_workers: 4
+
+### output
+output_dir: saves/qwen3vlmoe/lora/sft
+logging_steps: 1
+plot_loss: true
+overwrite_output_dir: true
+save_only_model: true
+report_to: none  # choices: [none, wandb, tensorboard, swanlab, mlflow]
+
+### train
+per_device_train_batch_size: 8
+gradient_accumulation_steps: 1
+learning_rate: 1.0e-4
+lr_scheduler_type: cosine
+warmup_ratio: 0.1
+bf16: true
+ddp_timeout: 180000000
+resume_from_checkpoint: null
+seed: 1234
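For reference, a config like this is normally launched through the project's CLI. A minimal programmatic sketch, assuming `llamafactory-cli` is installed and on PATH, and that distributed/FSDP settings come from the surrounding accelerate environment:

```python
# Launch the Ascend Qwen3-VL-MoE LoRA example via the LLaMA-Factory CLI.
import subprocess

subprocess.run(
    ["llamafactory-cli", "train", "examples/ascend/qwen3vlmoe_lora_sft_fsdp.yaml"],
    check=True,
)
```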
src/llamafactory/chat/vllm_engine.py
@@ -15,6 +15,7 @@
 import uuid
 from collections.abc import AsyncGenerator, AsyncIterator
 from typing import TYPE_CHECKING, Any, Optional, Union
+from packaging import version
 from typing_extensions import override

@@ -77,11 +78,18 @@ class VllmEngine(BaseEngine):
             "tensor_parallel_size": get_device_count() or 1,
             "gpu_memory_utilization": model_args.vllm_gpu_util,
             "disable_log_stats": True,
-            "disable_log_requests": True,
             "enforce_eager": model_args.vllm_enforce_eager,
             "enable_lora": model_args.adapter_name_or_path is not None,
             "max_lora_rank": model_args.vllm_max_lora_rank,
         }

+        import vllm
+
+        if version.parse(vllm.__version__) <= version.parse("0.10.0"):
+            engine_args["disable_log_requests"] = True
+        else:
+            engine_args["enable_log_requests"] = False
+
         if self.template.mm_plugin.__class__.__name__ != "BasePlugin":
             engine_args["limit_mm_per_prompt"] = {"image": 4, "video": 2, "audio": 2}
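For context, a standalone sketch of the version gate this hunk introduces: the diff suggests vLLM renamed (and inverted) its request-logging switch after 0.10.0, so the engine args must be built conditionally. Assumes `vllm` and `packaging` are importable:

```python
from packaging import version

import vllm

engine_args = {"disable_log_stats": True}
if version.parse(vllm.__version__) <= version.parse("0.10.0"):
    engine_args["disable_log_requests"] = True  # legacy flag name
else:
    engine_args["enable_log_requests"] = False  # newer, inverted flag
print(engine_args)
```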
src/llamafactory/hparams/model_args.py
@@ -174,6 +174,10 @@ class BaseModelArguments:
         default=True,
         metadata={"help": "Whether or not to use KV cache in generation."},
     )
+    use_v1_kernels: bool = field(
+        default=False,
+        metadata={"help": "Whether or not to use high-performance kernels in training."},
+    )
     infer_dtype: Literal["auto", "float16", "bfloat16", "float32"] = field(
         default="auto",
         metadata={"help": "Data type for model weights and activations at inference."},
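Since HfArgumentParser-style dataclass fields map one-to-one onto YAML/CLI keys, the new flag is toggled with `use_v1_kernels: true` in a training config or `--use_v1_kernels true` on the command line. A minimal sketch of that mapping (`ArgsSketch` is a hypothetical stand-in for `BaseModelArguments`):

```python
from dataclasses import dataclass, field

from transformers import HfArgumentParser


@dataclass
class ArgsSketch:  # hypothetical stand-in for BaseModelArguments
    use_v1_kernels: bool = field(
        default=False,
        metadata={"help": "Whether or not to use high-performance kernels in training."},
    )


(args,) = HfArgumentParser(ArgsSketch).parse_args_into_dataclasses(["--use_v1_kernels", "true"])
print(args.use_v1_kernels)  # True
```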
src/llamafactory/model/loader.py
@@ -213,6 +213,17 @@ def load_model(
     else:
         model.train()

+    # Borrow the v1 kernel plugin's ability to apply NPU fusion operators to v0 for the
+    # transition period; it is off by default and can be dropped once the transition ends.
+    if model_args.use_v1_kernels and is_trainable:
+        logger.warning_rank0(
+            "You are trying to use an experimental kernels feature; note that it is "
+            "not supported for all models. If you hit any error, disable this feature or report the issue."
+        )
+        from ..v1.plugins.model_plugins.kernels.registry import apply_available_kernels
+
+        model = apply_available_kernels(model)
+
     trainable_params, all_param = count_parameters(model)
     if is_trainable:
         param_stats = (
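The same entry point can be exercised outside of `load_model`; a usage sketch mirroring the test added at the bottom of this commit (model name taken from that test):

```python
from transformers import AutoModelForCausalLM

from llamafactory.v1.plugins.model_plugins.kernels.registry import apply_available_kernels

model = AutoModelForCausalLM.from_pretrained("llamafactory/tiny-random-qwen2.5")
model = apply_available_kernels(model)  # no-op on devices with no registered kernels
```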
src/llamafactory/v1/plugins/model_plugins/kernels/mlp/npu_fused_moe.py
@@ -11,3 +11,103 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
+import re
+import types
+
+import torch
+import torch_npu
+
+from .....extras.types import HFModel
+from ....trainer_plugins.distributed.accelerate import is_torch_npu_available
+from ..constants import DeviceType, KernelType
+from ..registry import MetaMoEKernel
+
+
+class GmmFunction(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, x, weight, group_list):
+        ctx.save_for_backward(x, weight)
+        ctx.group_list = group_list
+
+        fwd_output = torch_npu.npu_grouped_matmul(
+            [x], [weight], bias=None, group_list=group_list, split_item=2, group_type=0, group_list_type=1
+        )[0]
+        return fwd_output
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        input_tensor, weight = ctx.saved_tensors
+        group_list = ctx.group_list
+
+        weight = torch.transpose(weight, 1, 2)
+        grad_input = torch_npu.npu_grouped_matmul(
+            [grad_output], [weight], bias=None, group_list=group_list, split_item=2, group_type=0, group_list_type=1
+        )[0]
+        grad_weight = torch_npu.npu_grouped_matmul(
+            [input_tensor.T],
+            [grad_output],
+            bias=None,
+            group_list=group_list,
+            split_item=3,
+            group_type=2,
+            group_list_type=1,
+        )[0]
+        return grad_input, grad_weight, None
+
+
+def npu_group_gemm(x, weight, group_list):
+    output = GmmFunction.apply(x, weight, group_list)
+    return output
+
+
+def npu_experts_qwen3vlmoe_forward(
+    self, hidden_states: torch.Tensor, routing_weights: torch.Tensor, router_indices: torch.Tensor
+) -> torch.Tensor:
+    batch_size = hidden_states.shape[0]
+    hidden_states = hidden_states.reshape(-1, self.hidden_size)
+    permuted_hidden_states, row_ids_map = torch_npu.npu_moe_token_permute(
+        hidden_states, router_indices.to(torch.int32)
+    )
+    tokens_per_expert = torch.histc(router_indices, bins=self.num_experts, min=0, max=self.num_experts)
+    intermediate_hidden_states = npu_group_gemm(permuted_hidden_states, self.gate_up_proj, tokens_per_expert)
+    intermediate_activations = torch_npu.npu_swiglu(intermediate_hidden_states, dim=-1)
+    output = npu_group_gemm(intermediate_activations, self.down_proj, tokens_per_expert)
+    next_states = torch_npu.npu_moe_token_unpermute(output, row_ids_map, probs=routing_weights)
+    next_states = next_states.view(batch_size, -1, self.hidden_size)
+    return next_states
+
+
+def npu_moe_block_qwen3vlmoe_forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+    batch_size = hidden_states.shape[0]
+    hidden_states = hidden_states.reshape(-1, self.hidden_size)
+    router_logits = self.gate(hidden_states)
+    routing_weights = torch.nn.functional.softmax(router_logits, dim=-1, dtype=torch.float)
+    routing_weights, router_indices = torch.topk(routing_weights, self.top_k, dim=-1)
+    routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True)
+    routing_weights = routing_weights.to(hidden_states.dtype)
+    hidden_states = hidden_states.reshape(batch_size, -1, self.hidden_size)
+    routed_out = self.experts(hidden_states, routing_weights, router_indices)
+    return routed_out
+
+
+class NpuQwen3VLMoEFusedMoEKernel(MetaMoEKernel):
+    type = KernelType.MOE
+    device = DeviceType.NPU
+    npu_experts_kernel = npu_experts_qwen3vlmoe_forward
+    npu_moe_block_kernel = npu_moe_block_qwen3vlmoe_forward
+
+    @classmethod
+    def apply(cls, model, **kwargs) -> HFModel:
+        if not is_torch_npu_available():
+            return model
+
+        npu_experts_pattern = re.compile("Qwen3VLMoeTextExperts", re.IGNORECASE)
+        npu_moe_block_pattern = re.compile("Qwen3VLMoeTextSparseMoeBlock", re.IGNORECASE)
+
+        for _, module in model.named_modules():
+            if re.search(npu_experts_pattern, module.__class__.__name__):
+                module.forward = types.MethodType(cls.npu_experts_kernel, module)
+            elif re.search(npu_moe_block_pattern, module.__class__.__name__):
+                module.forward = types.MethodType(cls.npu_moe_block_kernel, module)
+        return model
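For intuition, a plain-PyTorch reference of what the grouped matmul computes, assuming `group_list` holds per-expert token counts (which `group_list_type=1` together with the `torch.histc` call above suggests): rows of `x` are partitioned by expert, and each chunk is multiplied by that expert's weight matrix. The NPU op fuses this loop into one kernel.

```python
import torch


def grouped_matmul_reference(x, weight, tokens_per_expert):
    """x: (tokens, hidden); weight: (experts, hidden, out); tokens_per_expert: (experts,)."""
    outputs, start = [], 0
    for expert_id, count in enumerate(tokens_per_expert.tolist()):
        count = int(count)
        outputs.append(x[start : start + count] @ weight[expert_id])
        start += count
    return torch.cat(outputs, dim=0)


x = torch.randn(10, 8)
w = torch.randn(4, 8, 16)            # 4 experts
counts = torch.tensor([3, 2, 4, 1])  # sums to 10
print(grouped_matmul_reference(x, w, counts).shape)  # torch.Size([10, 16])
```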
src/llamafactory/v1/plugins/model_plugins/kernels/mlp/npu_swiglu.py
@@ -20,7 +20,7 @@ import torch
 from .....extras.types import HFModel
 from ....trainer_plugins.distributed.accelerate import is_torch_npu_available
 from ..constants import DeviceType, KernelType
-from ..registry import KERNEL_REGISTRY, MetaSwiGluKernel
+from ..registry import MetaSwiGluKernel


 def _npu_swiglu_forward(self, hidden_state):
@@ -31,25 +31,85 @@ def _npu_swiglu_forward(self, hidden_state):
     )


+def _npu_swiglu_glm4_forward(self, hidden_states):
+    import torch_npu
+
+    up_states = self.gate_up_proj(hidden_states)
+    gate, up_states = up_states.chunk(2, dim=-1)
+    return self.down_proj(torch_npu.npu_swiglu(torch.cat((gate, up_states), dim=-1), dim=-1))
+
+
+def _npu_swiglu_gemma3ntext_forward(self, hidden_states):
+    import torch_npu
+
+    gate_proj = self.gate_proj(hidden_states)
+    if self.activation_sparsity > 0.0:
+        gate_proj = self._gaussian_topk(gate_proj)
+    down_proj = self.down_proj(
+        torch_npu.npu_swiglu(torch.cat((gate_proj, self.up_proj(hidden_states)), dim=-1), dim=-1)
+    )
+    return down_proj
+
+
 class NpuSwiGluKernel(MetaSwiGluKernel):
+    type = KernelType.SWIGLU
     device = DeviceType.NPU
     kernel = _npu_swiglu_forward

-    @classmethod
-    def register_kernel(cls, kernel_type=KernelType.SWIGLU, device_type=DeviceType.NPU):
-        KERNEL_REGISTRY.register(kernel_type, device_type, cls)
+    # Only apply the kernel to the modules listed below
+    expect_modules = frozenset(
+        {
+            "Qwen3VLMoeTextMLP",
+            "Qwen3VLTextMLP",
+            "Qwen3OmniMoeThinkerTextMLP",
+            "Qwen3OmniMoeMLP",
+            "Qwen3OmniMoeTalkerTextMLP",
+            "Qwen3OmniMoeCode2WavMlp",
+            "Qwen3NextMLP",
+            "Qwen3MoeMLP",
+            "Qwen3MLP",
+            "Qwen2MLP",
+            "Qwen2MoeMLP",
+            "Qwen2_5_VLMLP",
+            "Qwen2_5OmniMLP",
+            "Llama4TextMLP",
+            "LlamaMLP",
+            "Glm4MLP",
+            "Glm4MoeMLP",
+            "Glm4vMoeTextMLP",
+            "Gemma3MLP",
+            "Gemma2MLP",
+            "Gemma3nTextMLP",
+            "Phi3MLP",
+            "DeepseekV2MLP",
+            "DeepseekV3MLP",
+            "SeedOssMLP",
+        }
+    )

     @classmethod
     def apply(cls, model, **kwargs) -> "HFModel":
         if not is_torch_npu_available():
             return model

+        # Mapping of specific MLP modules to their corresponding kernel implementations
+        kernel_mapping = {
+            "Glm4MLP": _npu_swiglu_glm4_forward,
+            "Glm4vTextMLP": _npu_swiglu_glm4_forward,
+            "Phi3MLP": _npu_swiglu_glm4_forward,
+            "Gemma3nTextMLP": _npu_swiglu_gemma3ntext_forward,
+        }
+
         swiglu_pattern = re.compile("MLP", re.IGNORECASE)
         for name, module in model.named_modules():
-            # Match any module whose class name contains "RMSNorm"
-            if re.search(swiglu_pattern, module.__class__.__name__):
+            # Match any module whose class name contains "MLP"
+            if (
+                re.search(swiglu_pattern, module.__class__.__name__)
+                and module.__class__.__name__ in cls.expect_modules
+            ):
                 # Bind function as an instance method to preserve `self` semantics
                 # and replace the original forward
-                module.forward = types.MethodType(cls.kernel, module)
+                kernel_func = kernel_mapping.get(module.__class__.__name__, _npu_swiglu_forward)
+                module.forward = types.MethodType(kernel_func, module)

         return model
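A self-contained sketch of the dispatch pattern used in `apply` above: look up a forward by module class name, fall back to the generic SwiGLU forward, and bind it with `types.MethodType` so `self` still resolves to the module instance:

```python
import types

import torch


class TinyMLP(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(4, 4)

    def forward(self, x):
        return self.proj(x)


def fused_forward(self, x):  # stand-in for an NPU-fused implementation
    return self.proj(x) * 2


mlp = TinyMLP()
kernel_mapping = {"TinyMLP": fused_forward}
kernel_func = kernel_mapping.get(mlp.__class__.__name__, TinyMLP.forward)
mlp.forward = types.MethodType(kernel_func, mlp)  # instance-level patch; the class is untouched
print(mlp(torch.ones(1, 4)).shape)  # torch.Size([1, 4])
```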
src/llamafactory/v1/plugins/model_plugins/kernels/registry.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from abc import ABC, abstractmethod
+from abc import ABC, ABCMeta, abstractmethod
 from typing import Any, Callable, Optional

 from ....extras.types import HFModel
@@ -61,18 +61,67 @@ class KernelRegistry:
 KERNEL_REGISTRY = KernelRegistry()


-class MetaKernel(ABC):
+class AutoRegisterKernelMeta(ABCMeta):
+    """Metaclass that automatically registers kernel classes upon creation.
+
+    This metaclass checks whether a newly created class has both `type` and `device`
+    attributes defined. If so, it automatically registers the kernel in the
+    global KERNEL_REGISTRY, eliminating the need for manual registration.
+
+    To disable auto-registration for a specific class, set `auto_register = False`.
+    """
+
+    def __new__(mcs, name, bases, namespace, **kwargs):
+        cls = super().__new__(mcs, name, bases, namespace, **kwargs)
+
+        # Check if auto-registration is disabled
+        auto_register = namespace.get("auto_register", True)
+
+        # Only auto-register if the class defines both `type` and `device` and
+        # they are not None (this skips base classes like MetaKernel itself)
+        kernel_type = namespace.get("type")
+        device_type = namespace.get("device")
+
+        if auto_register and kernel_type is not None and device_type is not None:
+            # Auto-register this kernel
+            KERNEL_REGISTRY.register(kernel_type, device_type, cls)
+
+        return cls
+
+
+class MetaKernel(ABC, metaclass=AutoRegisterKernelMeta):
+    """Base class for all kernel implementations.
+
+    Subclasses are automatically registered when they define both `type` and `device`
+    attributes. To disable auto-registration, set `auto_register = False`.
+
+    Attributes:
+        type: The kernel type (e.g., KernelType.RMSNORM). Must be set in subclasses.
+        device: The device type (e.g., DeviceType.NPU). Must be set in subclasses.
+        kernel: The actual kernel function or implementation.
+        auto_register: Set to False to disable automatic registration (default: True).
+    """
+
     type: Optional[KernelType] = None
     device: Optional[DeviceType] = None
     kernel: Optional[Callable] = None

-    @classmethod
-    def register_kernel(cls, kernel_type: KernelType, device_type: DeviceType):
-        KERNEL_REGISTRY.register(kernel_type, device_type, cls)
-
     @classmethod
     @abstractmethod
     def apply(cls, model: HFModel, **kwargs) -> HFModel:
+        """Apply the kernel to the model.
+
+        This method should check whether the kernel can be applied (e.g., dependencies
+        are installed, target modules exist) and perform the kernel replacement.
+
+        Args:
+            model: The HuggingFace model to optimize.
+            **kwargs: Additional arguments for kernel application.
+
+        Returns:
+            The optimized model (may be the same object with modifications).
+        """
         raise NotImplementedError
@@ -106,16 +155,75 @@ class MetaMoEKernel(MetaKernel):
         raise NotImplementedError


-def discover_kernels(model: HFModel) -> list[MetaKernel]:
-    """Discover and construct MetaKernel instances for the current model/device.
-
-    This is a placeholder to be implemented: it should inspect the runtime
-    environment (device type, available extensions, model architecture) and
-    return an ordered list of MetaKernel instances to be applied. Each returned
-    MetaKernel must encapsulate its own replacement logic in `apply`.
-    """
-    # TODO: Implement auto discovery logic based on registry and device capabilities.
-    return []
+def _ensure_kernels_loaded() -> None:
+    """Ensure all kernel implementations are imported and registered.
+
+    This function dynamically imports all kernel implementation modules to trigger
+    their auto-registration. Python's module system ensures each module is only
+    executed once (cached in sys.modules), so repeated calls are safe and fast.
+    """
+    # List of kernel module paths to import
+    kernel_modules = [
+        "rms_norm.npu_rms_norm",
+        "rope.npu_rope",
+        "mlp.npu_swiglu",
+        "mlp.npu_fused_moe",
+        # Add new kernel modules here as they are created
+    ]
+
+    # Import each module to trigger kernel registration;
+    # Python's import system caches modules, so this is fast on subsequent calls
+    for module_name in kernel_modules:
+        try:
+            __import__(f"{__package__}.{module_name}", fromlist=["*"])
+        except ImportError:
+            # Silently ignore import errors (e.g., missing dependencies like torch_npu)
+            pass
+
+
+def discover_kernels(model: HFModel = None) -> list[type[MetaKernel]]:
+    """Discover and return all kernel classes registered for the current device.
+
+    This function inspects the runtime environment (device type) and returns
+    all MetaKernel classes registered for that device. Each kernel's `apply()`
+    method is responsible for checking whether it can actually be applied (e.g.,
+    required dependencies are installed, target modules exist in the model).
+
+    The function automatically discovers all kernels registered in KERNEL_REGISTRY
+    without requiring manual enumeration. On first call, it dynamically imports
+    all kernel implementation modules to trigger their auto-registration.
+
+    Args:
+        model: The HuggingFace model to apply kernels to.
+            TODO: implement kernel route detection based on the model structure.
+
+    Returns:
+        A list of MetaKernel classes available for the current device.
+    """
+    # Ensure all kernel modules are imported to trigger registration
+    _ensure_kernels_loaded()
+
+    discovered_kernels: list[type[MetaKernel]] = []
+
+    # Detect the current device type
+    accelerator = get_available_accelerator()
+    try:
+        device_type = DeviceType(accelerator.type)
+    except ValueError:
+        # Unknown device type, return an empty list
+        return discovered_kernels
+
+    # Skip CPU, as it typically has no optimized kernels
+    if device_type == DeviceType.CPU:
+        return discovered_kernels
+
+    # Iterate through the registry and collect all kernels for the current device
+    for kernel_type, devices in KERNEL_REGISTRY._registry.items():
+        kernel_cls = devices.get(device_type)
+        if kernel_cls is not None:
+            discovered_kernels.append(kernel_cls)
+
+    return discovered_kernels


 def apply_kernel(model: HFModel, kernel: type[MetaKernel], /, **kwargs) -> "HFModel":
@@ -136,3 +244,10 @@ def apply_kernel(model: HFModel, kernel: type[MetaKernel], /, **kwargs) -> "HFMo
     raise ValueError(
         f"{kernel} must be a MetaKernel subclass whose device matches the current accelerator; got {kernel.device} and {get_available_accelerator().type} instead."
     )
+
+
+def apply_available_kernels(model: HFModel, **kwargs) -> "HFModel":
+    """Apply all available kernels to the model."""
+    for kernel in discover_kernels(model):
+        model = apply_kernel(model, kernel, **kwargs)
+    return model
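A minimal standalone demo of the auto-registration pattern above (simplified: strings instead of the `KernelType`/`DeviceType` enums, a plain dict instead of `KernelRegistry`):

```python
from abc import ABC, ABCMeta

REGISTRY: dict[tuple[str, str], type] = {}


class AutoRegisterMeta(ABCMeta):
    def __new__(mcs, name, bases, namespace, **kwargs):
        cls = super().__new__(mcs, name, bases, namespace, **kwargs)
        if namespace.get("auto_register", True) and namespace.get("type") and namespace.get("device"):
            REGISTRY[(namespace["type"], namespace["device"])] = cls  # registered at class creation
        return cls


class Kernel(ABC, metaclass=AutoRegisterMeta):
    type = None    # the base class defines neither, so it is skipped
    device = None


class NpuRMSNorm(Kernel):   # picked up automatically
    type, device = "rmsnorm", "npu"


class Qwen2VLRoPE(Kernel):  # opts out; must be applied manually
    type, device = "rope", "npu"
    auto_register = False


print(REGISTRY)  # {('rmsnorm', 'npu'): <class '__main__.NpuRMSNorm'>}
```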
src/llamafactory/v1/plugins/model_plugins/kernels/rms_norm/npu_rms_norm.py
@@ -17,7 +17,7 @@ import types
 from .....extras.types import HFModel
 from ....trainer_plugins.distributed.accelerate import is_torch_npu_available
 from ..constants import DeviceType, KernelType
-from ..registry import KERNEL_REGISTRY, MetaRMSNormKernel
+from ..registry import MetaRMSNormKernel


 def _npu_rms_forward(self, hidden_states):
@@ -38,14 +38,10 @@ def _npu_rms_forward(self, hidden_states):
 class NpuRMSNormKernel(MetaRMSNormKernel):
     """NPU kernel wrapper for RMSNorm that applies the replacement within a model."""

+    type = KernelType.RMSNORM
     device = DeviceType.NPU
     kernel = _npu_rms_forward

-    @classmethod
-    def register_kernel(cls, kernel_type=KernelType.RMSNORM, device_type=DeviceType.NPU):
-        """Register the NPU RMSNorm forward implementation to the global registry."""
-        KERNEL_REGISTRY.register(kernel_type, device_type, cls)
-
     @classmethod
     def apply(cls, model, **kwargs) -> HFModel:
         """Iterate the model and apply NPU-optimized forward to matched RMSNorm modules.
src/llamafactory/v1/plugins/model_plugins/kernels/rope/npu_rope.py
@@ -19,7 +19,7 @@ import torch
 from .....extras.types import HFModel
 from ....trainer_plugins.distributed.accelerate import is_torch_npu_available
 from ..constants import DeviceType, KernelType
-from ..registry import KERNEL_REGISTRY, MetaRoPEKernel
+from ..registry import MetaRoPEKernel


 def _apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
@@ -51,13 +51,10 @@ def _apply_multimodal_rotary_pos_emb_qwen25_vl(q, k, cos, sin, mrope_section, un


 class NpuRoPEKernel(MetaRoPEKernel):
+    type = KernelType.ROPE
     device = DeviceType.NPU
     kernel = _apply_rotary_pos_emb

-    @classmethod
-    def register_kernel(cls, kernel_type=KernelType.ROPE, device_type=DeviceType.NPU):
-        KERNEL_REGISTRY.register(kernel_type, device_type, cls)
-
     @classmethod
     def apply(cls, model, **kwargs) -> "HFModel":
         """Apply RoPE acceleration by monkey-patching `apply_rotary_pos_emb`.
@@ -88,12 +85,16 @@ class NpuRoPEKernel(MetaRoPEKernel):


 class NpuQwen2VLRoPEKernel(MetaRoPEKernel):
+    """Qwen2-VL specific RoPE kernel - not auto-registered.
+
+    This kernel targets specific models (Qwen2-VL) and should be manually
+    applied when needed rather than auto-discovered.
+    """
+
+    type = KernelType.ROPE
     device = DeviceType.NPU
     kernel = _apply_multimodal_rotary_pos_emb_qwen25_vl
+    auto_register = False  # Disable auto-registration for this model-specific kernel

-    @classmethod
-    def register_kernel(cls, kernel_type=KernelType.ROPE, device_type=DeviceType.NPU):
-        KERNEL_REGISTRY.register(kernel_type, device_type, cls)
-
     @classmethod
     def apply(cls, model, **kwargs) -> "HFModel":
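Because `NpuQwen2VLRoPEKernel` sets `auto_register = False`, `discover_kernels()` never returns it; callers apply it explicitly. A usage sketch (import path inferred from the module list in `_ensure_kernels_loaded`; `apply` returns the model unchanged off-NPU):

```python
from transformers import AutoModelForCausalLM

from llamafactory.v1.plugins.model_plugins.kernels.rope.npu_rope import NpuQwen2VLRoPEKernel

model = AutoModelForCausalLM.from_pretrained("llamafactory/tiny-random-qwen2.5")
model = NpuQwen2VLRoPEKernel.apply(model)  # no-op unless torch_npu is available
```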
tests (kernel plugin)
@@ -42,3 +42,23 @@ class TestKernelPlugin(unittest.TestCase):

         model = apply_kernel(model, npu_swiglu.NpuSwiGluKernel)
         assert model.model.layers[0].mlp.forward is not original_swiglu_forward
+
+
+class Test_Use_V1_Kernels(unittest.TestCase):
+    @patch("torch.accelerator.current_accelerator")
+    def test_use_v1_kernels(self, mock_get_accelerator):
+        mock_device = MagicMock()
+        mock_device.type = "npu"
+        mock_get_accelerator.return_value = mock_device
+
+        model = AutoModelForCausalLM.from_pretrained("llamafactory/tiny-random-qwen2.5")
+
+        original_rmsnorm_forward = model.model.layers[0].input_layernorm.forward
+        original_swiglu_forward = model.model.layers[0].mlp.forward
+
+        from llamafactory.v1.plugins.model_plugins.kernels.registry import apply_available_kernels
+
+        model = apply_available_kernels(model)
+
+        # compare the bound `forward` methods, not the module object itself
+        assert model.model.layers[0].input_layernorm.forward is not original_rmsnorm_forward
+        assert model.model.layers[0].mlp.forward is not original_swiglu_forward