[npu] add Qwen3.5 support with Partial RoPE and Hybrid Attention (#10421)

Co-authored-by: Curnane <mingliangfu@users.noreply.github.com>
This commit is contained in:
curnane-lab
2026-04-27 23:36:07 +08:00
committed by GitHub
parent 99464b3d03
commit 2092abc217
4 changed files with 224 additions and 43 deletions

View File

@@ -0,0 +1,20 @@
# Accelerate FSDP2 launch configuration for full fine-tuning Qwen3.5 on Ascend NPU.
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: FSDP
downcast_bf16: 'no'
fsdp_config:
  # NOTE(review): these keys must be nested under `fsdp_config:`; the flat
  # diff rendering dropped the two-space indentation, which would make the
  # file invalid YAML for Accelerate. Restored here.
  fsdp_version: 2
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  # Wrap both the text decoder layer and the vision block of Qwen3.5.
  fsdp_transformer_layer_cls_to_wrap: Qwen3_5DecoderLayer,Qwen3_5VisionBlock
  fsdp_cpu_ram_efficient_loading: true
  fsdp_offload_params: false
  fsdp_reshard_after_forward: true
  fsdp_state_dict_type: FULL_STATE_DICT
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 8 # Change to match your NPU count (e.g., 8 for A2, 16 for A3)
rdzv_backend: static
same_network: true
use_cpu: false

View File

@@ -0,0 +1,47 @@
# Start FSDP2 full fine-tuning on Ascend NPU
# Usage:
#   accelerate launch \
#     --config_file examples/accelerate/fsdp2_config_qwen35.yaml \
#     src/train.py examples/ascend/qwen3_5_full_sft_fsdp2.yaml
#
# Note: Change `num_processes` in fsdp2_config_qwen35.yaml to match your NPU count

### model
model_name_or_path: Qwen/Qwen3.5-4B
trust_remote_code: true
# presumably enables the NPU-optimized v1 kernel path (RMSNorm/RoPE fusions) — verify against kernel registry
use_v1_kernels: true
flash_attn: fa2

### method
stage: sft
do_train: true
finetuning_type: full

### dataset
dataset: alpaca_en_demo
# nothink template variant — presumably disables the thinking block for Qwen3.5; confirm against template definitions
template: qwen3_5_nothink
cutoff_len: 2048
max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16
dataloader_num_workers: 4

### output
output_dir: saves/Qwen3.5-4B/full/sft
logging_steps: 1
save_steps: 500
max_steps: 500
plot_loss: true
overwrite_output_dir: true
save_only_model: false
report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]

### train
per_device_train_batch_size: 8
gradient_accumulation_steps: 1
learning_rate: 1.0e-5
lr_scheduler_type: cosine
warmup_ratio: 0.1
bf16: true
ddp_timeout: 1800
resume_from_checkpoint: null

View File

@@ -23,25 +23,104 @@ Init Phase:
import re import re
import types import types
import torch
import torch.nn.functional as F
from ......accelerator.helper import DeviceType from ......accelerator.helper import DeviceType
from ......utils.types import HFModel from ......utils.types import HFModel
from ...base import BaseKernel from ...base import BaseKernel
from ...registry import register_kernel from ...registry import register_kernel
def npu_rms_norm_forward(self, hidden_states): try:
"""NPU forward implementation for RMSNorm. import torch_npu
except ImportError:
pass
def _should_use_residual_rmsnorm(module):
"""Detect whether the module uses residual RMSNorm parameterization.
Residual RMSNorm uses ``scale = 1.0 + weight`` where weight is initialized to 0,
while standard RMSNorm uses ``scale = weight`` where weight is initialized to 1.
Args: Args:
self: RMSNorm module instance with `weight` and `variance_epsilon`. module (nn.Module): The RMSNorm module to check.
hidden_states (Tensor): Input hidden states tensor, same shape as the baseline.
Returns:
bool: ``True`` if the module uses residual parameterization, ``False`` otherwise.
.. note::
This detection ensures compatibility with future model versions (e.g., Qwen3.6, Qwen4.0)
without hardcoding version numbers. Two methods are used: weight value inspection
(most reliable) and class name pattern matching (backward compatibility).
"""
if hasattr(module, "weight") and module.weight is not None:
weight_mean = module.weight.data.mean().item()
if abs(weight_mean) < 0.3:
return True
class_name = module.__class__.__name__
residual_patterns = ["Qwen3_5", "Qwen3_6", "Qwen4"]
for pattern in residual_patterns:
if pattern in class_name:
return True
return False
def npu_rms_norm_forward(self, hidden_states):
    """NPU forward implementation for standard RMSNorm.

    Applies ``torch_npu.npu_rms_norm``, adjusting the scale to ``1.0 + weight``
    when the module uses residual RMSNorm parameterization (see
    ``_should_use_residual_rmsnorm``).

    Args:
        self (nn.Module): RMSNorm module with ``weight`` and ``variance_epsilon``
            (or ``eps``) attributes.
        hidden_states (Tensor): Input hidden states tensor.

    Returns:
        Tensor: Normalized tensor consistent with the baseline RMSNorm behavior.
    """
    # Explicit None checks: the previous `getattr(...) or getattr(...)` idiom
    # silently discarded a legitimate epsilon of 0.0 because 0.0 is falsy.
    eps = getattr(self, "variance_epsilon", None)
    if eps is None:
        eps = getattr(self, "eps", 1e-6)
    weight = getattr(self, "weight", None)
    if weight is None:
        # NOTE(review): weight-less module — forwards None as the scale; assumes
        # torch_npu.npu_rms_norm tolerates it (unchanged from prior behavior) — confirm.
        return torch_npu.npu_rms_norm(hidden_states, weight, epsilon=eps)[0]
    if _should_use_residual_rmsnorm(self):
        # Residual parameterization: effective scale is 1 + weight.
        effective_weight = 1.0 + weight.float()
    else:
        effective_weight = weight.float()
    return torch_npu.npu_rms_norm(hidden_states, effective_weight.to(hidden_states.dtype), epsilon=eps)[0]
def npu_gated_rms_norm_forward(self, hidden_states, gate=None):
    """NPU forward implementation for Gated RMSNorm with FP32 computation.

    Performs RMSNorm and the gated SiLU multiplication in float32 for numerical
    stability, then casts back to the input dtype. Unlike standard RMSNorm,
    Gated RMSNorm in Qwen3.5 uses standard parameterization (``scale = weight``
    where weight is initialized to 1), so the residual weight adjustment
    (``1.0 + weight``) is not applied here.

    Args:
        self (nn.Module): The Gated RMSNorm module instance.
        hidden_states (Tensor): Input hidden states tensor.
        gate (Tensor, optional): Gate tensor for SiLU activation. Defaults to ``None``.

    Returns:
        Tensor: Output tensor cast back to the original input dtype.
    """
    input_dtype = hidden_states.dtype
    hidden_states = hidden_states.to(torch.float32)
    # Explicit None check instead of `or`: 0.0 is falsy, so the `or` idiom would
    # wrongly replace an explicit epsilon of 0.0 with the fallback value.
    eps = getattr(self, "variance_epsilon", None)
    if eps is None:
        eps = getattr(self, "eps", 1e-6)
    hidden_states = torch_npu.npu_rms_norm(hidden_states, self.weight.float(), epsilon=eps)[0]
    if gate is not None:
        hidden_states = hidden_states * F.silu(gate.to(torch.float32))
    return hidden_states.to(input_dtype)
@register_kernel @register_kernel
@@ -55,12 +134,9 @@ class NpuRMSNormKernel(BaseKernel):
def apply(cls, **kwargs) -> "HFModel": def apply(cls, **kwargs) -> "HFModel":
"""Iterate the model and apply NPU-optimized forward to matched RMSNorm modules. """Iterate the model and apply NPU-optimized forward to matched RMSNorm modules.
Key points: Matches modules whose class name contains "RMSNorm" (case-insensitive) and binds
- Match modules whose class name contains "RMSNorm" (case-insensitive). the appropriate NPU-optimized forward function as an instance method via
- Bind `_npu_rms_forward` as an instance method via `types.MethodType` to ``types.MethodType`` to replace the original ``forward``.
replace the original `forward`.
- Do not modify weights, hyperparameters, or module structure to ensure
numerical behavior and interface consistency.
Args: Args:
**kwargs: Keyword arguments containing the model. **kwargs: Keyword arguments containing the model.
@@ -69,7 +145,7 @@ class NpuRMSNormKernel(BaseKernel):
HFModel: The model with NPU fused RMSNorm. HFModel: The model with NPU fused RMSNorm.
Raises: Raises:
RuntimeError: If torch_npu is not available. RuntimeError: If ``torch_npu`` is not available.
ValueError: If the model is not provided. ValueError: If the model is not provided.
""" """
model = kwargs.get("model") model = kwargs.get("model")
@@ -81,11 +157,11 @@ class NpuRMSNormKernel(BaseKernel):
rms_norm_pattern = re.compile("RMSNorm", re.IGNORECASE) rms_norm_pattern = re.compile("RMSNorm", re.IGNORECASE)
for name, module in model.named_modules(): for _, module in model.named_modules():
# Match any module whose class name contains "RMSNorm"
if re.search(rms_norm_pattern, module.__class__.__name__): if re.search(rms_norm_pattern, module.__class__.__name__):
# Bind function as an instance method to preserve `self` semantics if "Gated" in module.__class__.__name__:
# and replace the original forward module.forward = types.MethodType(npu_gated_rms_norm_forward, module)
else:
module.forward = types.MethodType(npu_rms_norm_forward, module) module.forward = types.MethodType(npu_rms_norm_forward, module)
return model return model

View File

@@ -39,40 +39,80 @@ except ImportError:
pass pass
def _apply_npu_rotary_emb(q, k, cos, sin):
    """Apply NPU-accelerated rotary embedding with automatic Partial RoPE detection.

    Whether Partial or Full RoPE applies is decided by comparing the last
    dimension of ``cos``/``sin`` against that of ``q``/``k``, ensuring
    compatibility with future model versions without hardcoding.

    Args:
        q (Tensor): Query tensor.
        k (Tensor): Key tensor.
        cos (Tensor): Cosine part of rotary embedding (already unsqueezed).
        sin (Tensor): Sine part of rotary embedding (already unsqueezed).

    Returns:
        tuple[Tensor, Tensor]: The embedded query and key tensors ``(q_embed, k_embed)``.
    """
    rotary_dim = cos.shape[-1]
    if rotary_dim >= q.shape[-1]:
        # Full RoPE: the whole head dimension is rotated.
        q_embed = torch_npu.npu_rotary_mul(q, cos, sin).to(q.dtype)
        k_embed = torch_npu.npu_rotary_mul(k, cos, sin).to(k.dtype)
        return q_embed, k_embed
    # Partial RoPE: rotate only the leading `rotary_dim` channels and pass the
    # remaining channels through untouched.
    q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:]
    k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:]
    q_embed = torch.cat([torch_npu.npu_rotary_mul(q_rot, cos, sin).to(q.dtype), q_pass], dim=-1)
    k_embed = torch.cat([torch_npu.npu_rotary_mul(k_rot, cos, sin).to(k.dtype), k_pass], dim=-1)
    return q_embed, k_embed
def _apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Apply Rotary Position Embedding to query and key tensors using NPU optimization.

    Supports both Full RoPE and Partial RoPE automatically based on the
    dimension ratio between ``cos``/``sin`` and ``q``/``k`` tensors (handled
    inside ``_apply_npu_rotary_emb``).

    Args:
        q (Tensor): Query tensor.
        k (Tensor): Key tensor.
        cos (Tensor): Cosine part of embedding.
        sin (Tensor): Sine part of embedding.
        position_ids (Tensor, optional): Position IDs (unused; kept for
            interface compatibility). Defaults to ``None``.
        unsqueeze_dim (int): Dimension to unsqueeze cos and sin. Defaults to 1.

    Returns:
        tuple[Tensor, Tensor]: The embedded query and key tensors ``(q_embed, k_embed)``.
    """
    cos_u = cos.unsqueeze(unsqueeze_dim)
    sin_u = sin.unsqueeze(unsqueeze_dim)
    return _apply_npu_rotary_emb(q, k, cos_u, sin_u)
def _apply_multimodal_rotary_pos_emb_qwen25_vl(q, k, cos, sin, mrope_section, unsqueeze_dim=1):
"""Apply Rotary Position Embedding with multimodal sections (Qwen2-VL) on NPU.
This function supports Partial RoPE for multimodal inputs with automatic dimension
detection, ensuring compatibility with future model versions.
Args:
q (Tensor): Query tensor.
k (Tensor): Key tensor.
cos (Tensor): Cosine part of embedding.
sin (Tensor): Sine part of embedding.
mrope_section (list[int]): Multimodal RoPE section sizes.
unsqueeze_dim (int): Dimension to unsqueeze cos and sin. Defaults to 1.
Returns:
tuple[Tensor, Tensor]: The embedded query and key tensors ``(q_embed, k_embed)``.
""" """
mrope_section = mrope_section * 2 mrope_section = mrope_section * 2
cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze( cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze(
@@ -82,9 +122,7 @@ def _apply_multimodal_rotary_pos_emb_qwen25_vl(q, k, cos, sin, mrope_section, un
unsqueeze_dim unsqueeze_dim
) )
q_embed = torch_npu.npu_rotary_mul(q, cos, sin) return _apply_npu_rotary_emb(q, k, cos, sin)
k_embed = torch_npu.npu_rotary_mul(k, cos, sin)
return q_embed, k_embed
@register_kernel @register_kernel
@@ -96,12 +134,12 @@ class NpuRoPEKernel(BaseKernel):
@classmethod @classmethod
def apply(cls, **kwargs) -> "HFModel": def apply(cls, **kwargs) -> "HFModel":
"""Apply RoPE acceleration by monkey-patching `apply_rotary_pos_emb`. """Apply RoPE acceleration by monkey-patching ``apply_rotary_pos_emb``.
This function iterates through the model's modules to find attention layers, Iterates through the model's modules to find attention layers, identifies
identifies the module where they are defined, and replaces the original the module where they are defined, and replaces the original
`apply_rotary_pos_emb` function in that module's namespace with the ``apply_rotary_pos_emb`` function in that module's namespace with the
NPU-accelerated version from this file. NPU-accelerated version.
Args: Args:
**kwargs: Keyword arguments containing the model. **kwargs: Keyword arguments containing the model.
@@ -110,7 +148,7 @@ class NpuRoPEKernel(BaseKernel):
HFModel: The model with patched RoPE functions. HFModel: The model with patched RoPE functions.
Raises: Raises:
RuntimeError: If dependencies are not met. RuntimeError: If ``torch_npu`` is not available.
ValueError: If the model is not provided. ValueError: If the model is not provided.
""" """
if not cls.check_deps(): if not cls.check_deps():