[npu] add Qwen3.5 support with Partial RoPE and Hybrid Attention (#10421)

Co-authored-by: Curnane <mingliangfu@users.noreply.github.com>
This commit is contained in:
curnane-lab
2026-04-27 23:36:07 +08:00
committed by GitHub
parent 99464b3d03
commit 2092abc217
4 changed files with 224 additions and 43 deletions

View File

@@ -0,0 +1,20 @@
# Hugging Face Accelerate launch config: FSDP v2 full-parameter training
# on a single machine with bf16 mixed precision (Ascend NPU target).
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: FSDP
downcast_bf16: 'no'
fsdp_config:
  fsdp_version: 2
  # Wrap each listed transformer layer class as its own FSDP unit.
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_transformer_layer_cls_to_wrap: Qwen3_5DecoderLayer,Qwen3_5VisionBlock
  # Load checkpoint shards RAM-efficiently; keep parameters on device
  # (no CPU offload) and reshard after each forward pass.
  fsdp_cpu_ram_efficient_loading: true
  fsdp_offload_params: false
  fsdp_reshard_after_forward: true
  fsdp_state_dict_type: FULL_STATE_DICT
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 8 # Change to match your NPU count (e.g., 8 for A2, 16 for A3)
rdzv_backend: static
same_network: true
use_cpu: false

View File

@@ -0,0 +1,47 @@
# Start FSDP2 full fine-tuning on Ascend NPU
# Usage:
#   accelerate launch \
#     --config_file examples/accelerate/fsdp2_config_qwen35.yaml \
#     src/train.py examples/ascend/qwen3_5_full_sft_fsdp2.yaml
#
# Note: Change `num_processes` in fsdp2_config_qwen35.yaml to match your NPU count

### model
model_name_or_path: Qwen/Qwen3.5-4B
trust_remote_code: true
use_v1_kernels: true
flash_attn: fa2

### method
stage: sft
do_train: true
finetuning_type: full

### dataset
dataset: alpaca_en_demo
template: qwen3_5_nothink
cutoff_len: 2048
max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16
dataloader_num_workers: 4

### output
output_dir: saves/Qwen3.5-4B/full/sft
logging_steps: 1
save_steps: 500
max_steps: 500 # step-based stopping; overrides epoch-based training length
plot_loss: true
overwrite_output_dir: true
save_only_model: false
report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]

### train
per_device_train_batch_size: 8
gradient_accumulation_steps: 1
learning_rate: 1.0e-5
lr_scheduler_type: cosine
warmup_ratio: 0.1
bf16: true
ddp_timeout: 1800
resume_from_checkpoint: null

View File

@@ -23,25 +23,104 @@ Init Phase:
import re
import types
import torch
import torch.nn.functional as F
from ......accelerator.helper import DeviceType
from ......utils.types import HFModel
from ...base import BaseKernel
from ...registry import register_kernel
def npu_rms_norm_forward(self, hidden_states):
"""NPU forward implementation for RMSNorm.
try:
import torch_npu
except ImportError:
pass
def _should_use_residual_rmsnorm(module):
"""Detect whether the module uses residual RMSNorm parameterization.
Residual RMSNorm uses ``scale = 1.0 + weight`` where weight is initialized to 0,
while standard RMSNorm uses ``scale = weight`` where weight is initialized to 1.
Args:
self: RMSNorm module instance with `weight` and `variance_epsilon`.
hidden_states (Tensor): Input hidden states tensor, same shape as the baseline.
module (nn.Module): The RMSNorm module to check.
Returns:
bool: ``True`` if the module uses residual parameterization, ``False`` otherwise.
.. note::
This detection ensures compatibility with future model versions (e.g., Qwen3.6, Qwen4.0)
without hardcoding version numbers. Two methods are used: weight value inspection
(most reliable) and class name pattern matching (backward compatibility).
"""
if hasattr(module, "weight") and module.weight is not None:
weight_mean = module.weight.data.mean().item()
if abs(weight_mean) < 0.3:
return True
class_name = module.__class__.__name__
residual_patterns = ["Qwen3_5", "Qwen3_6", "Qwen4"]
for pattern in residual_patterns:
if pattern in class_name:
return True
return False
def npu_rms_norm_forward(self, hidden_states):
    """NPU forward implementation for standard RMSNorm.

    Args:
        self (nn.Module): The RMSNorm module instance with ``weight`` and
            ``variance_epsilon``.
        hidden_states (Tensor): Input hidden states tensor.

    Returns:
        Tensor: Normalized tensor consistent with the baseline RMSNorm
            behavior.
    """
    # Some modules expose `variance_epsilon`, others `eps`; fall back to 1e-6.
    eps = getattr(self, "variance_epsilon", None) or getattr(self, "eps", 1e-6)

    weight = getattr(self, "weight", None)
    if weight is None:
        # Mirrors the baseline: hand the weight straight to the fused kernel.
        return torch_npu.npu_rms_norm(hidden_states, self.weight, epsilon=eps)[0]

    # Residual parameterization stores scale as (1.0 + weight); fold it in
    # so the fused kernel sees the effective scale.
    if _should_use_residual_rmsnorm(self):
        scale = 1.0 + weight.float()
    else:
        scale = weight.float()
    return torch_npu.npu_rms_norm(hidden_states, scale.to(hidden_states.dtype), epsilon=eps)[0]
def npu_gated_rms_norm_forward(self, hidden_states, gate=None):
    """NPU forward implementation for Gated RMSNorm (FP32 compute path).

    RMSNorm and the gated SiLU multiplication are performed in FP32 for
    numerical stability. Unlike standard RMSNorm, Gated RMSNorm in Qwen3.5
    uses standard parameterization (``scale = weight`` with weight
    initialized to 1), so the residual (``1.0 + weight``) adjustment is not
    applied here.

    Args:
        self (nn.Module): The Gated RMSNorm module instance.
        hidden_states (Tensor): Input hidden states tensor.
        gate (Tensor, optional): Gate tensor for SiLU activation. Defaults
            to ``None``.

    Returns:
        Tensor: Output tensor cast back to the original input dtype.
    """
    orig_dtype = hidden_states.dtype
    eps = getattr(self, "variance_epsilon", None) or getattr(self, "eps", 1e-6)
    # Normalize in FP32; cast back only at the very end.
    normed = torch_npu.npu_rms_norm(hidden_states.to(torch.float32), self.weight.float(), epsilon=eps)[0]
    if gate is not None:
        normed = normed * F.silu(gate.to(torch.float32))
    return normed.to(orig_dtype)
@register_kernel
@@ -55,12 +134,9 @@ class NpuRMSNormKernel(BaseKernel):
def apply(cls, **kwargs) -> "HFModel":
"""Iterate the model and apply NPU-optimized forward to matched RMSNorm modules.
Key points:
- Match modules whose class name contains "RMSNorm" (case-insensitive).
- Bind `_npu_rms_forward` as an instance method via `types.MethodType` to
replace the original `forward`.
- Do not modify weights, hyperparameters, or module structure to ensure
numerical behavior and interface consistency.
Matches modules whose class name contains "RMSNorm" (case-insensitive) and binds
the appropriate NPU-optimized forward function as an instance method via
``types.MethodType`` to replace the original ``forward``.
Args:
**kwargs: Keyword arguments containing the model.
@@ -69,7 +145,7 @@ class NpuRMSNormKernel(BaseKernel):
HFModel: The model with NPU fused RMSNorm.
Raises:
RuntimeError: If torch_npu is not available.
RuntimeError: If ``torch_npu`` is not available.
ValueError: If the model is not provided.
"""
model = kwargs.get("model")
@@ -81,11 +157,11 @@ class NpuRMSNormKernel(BaseKernel):
rms_norm_pattern = re.compile("RMSNorm", re.IGNORECASE)
for name, module in model.named_modules():
# Match any module whose class name contains "RMSNorm"
for _, module in model.named_modules():
if re.search(rms_norm_pattern, module.__class__.__name__):
# Bind function as an instance method to preserve `self` semantics
# and replace the original forward
if "Gated" in module.__class__.__name__:
module.forward = types.MethodType(npu_gated_rms_norm_forward, module)
else:
module.forward = types.MethodType(npu_rms_norm_forward, module)
return model

View File

@@ -39,40 +39,80 @@ except ImportError:
pass
def _apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors using NPU optimization.
def _apply_npu_rotary_emb(q, k, cos, sin):
    """Apply NPU-accelerated rotary embedding with automatic Partial RoPE detection.

    Whether Partial or Full RoPE is used is decided from the last-dimension
    sizes of ``cos``/``sin`` versus ``q``/``k``, so future model versions
    work without hardcoded checks.

    Args:
        q (Tensor): Query tensor.
        k (Tensor): Key tensor.
        cos (Tensor): Cosine part of rotary embedding (already unsqueezed).
        sin (Tensor): Sine part of rotary embedding (already unsqueezed).

    Returns:
        tuple[Tensor, Tensor]: The embedded query and key tensors
            ``(q_embed, k_embed)``.
    """
    rot_dim = cos.shape[-1]
    if rot_dim < q.shape[-1]:
        # Partial RoPE: rotate only the first `rot_dim` channels and carry
        # the remaining channels through unchanged.
        q_rotated = torch_npu.npu_rotary_mul(q[..., :rot_dim], cos, sin).to(q.dtype)
        k_rotated = torch_npu.npu_rotary_mul(k[..., :rot_dim], cos, sin).to(k.dtype)
        q_embed = torch.cat([q_rotated, q[..., rot_dim:]], dim=-1)
        k_embed = torch.cat([k_rotated, k[..., rot_dim:]], dim=-1)
        return q_embed, k_embed
    # Full RoPE: rotate the entire head dimension.
    q_embed = torch_npu.npu_rotary_mul(q, cos, sin).to(q.dtype)
    k_embed = torch_npu.npu_rotary_mul(k, cos, sin).to(k.dtype)
    return q_embed, k_embed
def _apply_multimodal_rotary_pos_emb_qwen25_vl(q, k, cos, sin, mrope_section, unsqueeze_dim=1):
"""Applies Rotary Position Embedding with multimodal sections (Qwen2-VL) on NPU.
def _apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Apply Rotary Position Embedding to query and key tensors using NPU optimization.

    Both Full RoPE and Partial RoPE are supported automatically based on
    the dimension ratio between ``cos``/``sin`` and ``q``/``k``.

    Args:
        q (Tensor): Query tensor.
        k (Tensor): Key tensor.
        cos (Tensor): Cosine part of embedding.
        sin (Tensor): Sine part of embedding.
        position_ids (Tensor, optional): Position IDs. Defaults to ``None``.
            Kept for interface compatibility; unused here.
        unsqueeze_dim (int): Dimension to unsqueeze cos and sin. Defaults to 1.

    Returns:
        tuple[Tensor, Tensor]: The embedded query and key tensors
            ``(q_embed, k_embed)``.
    """
    return _apply_npu_rotary_emb(
        q, k, cos.unsqueeze(unsqueeze_dim), sin.unsqueeze(unsqueeze_dim)
    )
def _apply_multimodal_rotary_pos_emb_qwen25_vl(q, k, cos, sin, mrope_section, unsqueeze_dim=1):
"""Apply Rotary Position Embedding with multimodal sections (Qwen2-VL) on NPU.
This function supports Partial RoPE for multimodal inputs with automatic dimension
detection, ensuring compatibility with future model versions.
Args:
q (Tensor): Query tensor.
k (Tensor): Key tensor.
cos (Tensor): Cosine part of embedding.
sin (Tensor): Sine part of embedding.
mrope_section (list[int]): Multimodal RoPE section sizes.
unsqueeze_dim (int): Dimension to unsqueeze cos and sin. Defaults to 1.
Returns:
tuple[Tensor, Tensor]: The embedded query and key tensors ``(q_embed, k_embed)``.
"""
mrope_section = mrope_section * 2
cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze(
@@ -82,9 +122,7 @@ def _apply_multimodal_rotary_pos_emb_qwen25_vl(q, k, cos, sin, mrope_section, un
unsqueeze_dim
)
q_embed = torch_npu.npu_rotary_mul(q, cos, sin)
k_embed = torch_npu.npu_rotary_mul(k, cos, sin)
return q_embed, k_embed
return _apply_npu_rotary_emb(q, k, cos, sin)
@register_kernel
@@ -96,12 +134,12 @@ class NpuRoPEKernel(BaseKernel):
@classmethod
def apply(cls, **kwargs) -> "HFModel":
"""Apply RoPE acceleration by monkey-patching `apply_rotary_pos_emb`.
"""Apply RoPE acceleration by monkey-patching ``apply_rotary_pos_emb``.
This function iterates through the model's modules to find attention layers,
identifies the module where they are defined, and replaces the original
`apply_rotary_pos_emb` function in that module's namespace with the
NPU-accelerated version from this file.
Iterates through the model's modules to find attention layers, identifies
the module where they are defined, and replaces the original
``apply_rotary_pos_emb`` function in that module's namespace with the
NPU-accelerated version.
Args:
**kwargs: Keyword arguments containing the model.
@@ -110,7 +148,7 @@ class NpuRoPEKernel(BaseKernel):
HFModel: The model with patched RoPE functions.
Raises:
RuntimeError: If dependencies are not met.
RuntimeError: If ``torch_npu`` is not available.
ValueError: If the model is not provided.
"""
if not cls.check_deps():