mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2026-05-05 07:38:55 +08:00
[npu] add Qwen3.5 support with Partial RoPE and Hybrid Attention (#10421)
Co-authored-by: Curnane <mingliangfu@users.noreply.github.com>
This commit is contained in:
20
examples/accelerate/fsdp2_config_qwen35.yaml
Normal file
20
examples/accelerate/fsdp2_config_qwen35.yaml
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: FSDP
downcast_bf16: 'no'
fsdp_config:
  fsdp_version: 2
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_transformer_layer_cls_to_wrap: Qwen3_5DecoderLayer,Qwen3_5VisionBlock
  fsdp_cpu_ram_efficient_loading: true
  fsdp_offload_params: false
  fsdp_reshard_after_forward: true
  fsdp_state_dict_type: FULL_STATE_DICT
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 8  # Change to match your NPU count (e.g., 8 for A2, 16 for A3)
rdzv_backend: static
same_network: true
use_cpu: false
|
||||||
47
examples/ascend/qwen3_5_full_sft_fsdp2.yaml
Normal file
47
examples/ascend/qwen3_5_full_sft_fsdp2.yaml
Normal file
@@ -0,0 +1,47 @@
|
|||||||
|
# Start FSDP2 full fine-tuning on Ascend NPU
|
||||||
|
# Usage:
|
||||||
|
# accelerate launch \
|
||||||
|
# --config_file examples/accelerate/fsdp2_config_qwen35.yaml \
|
||||||
|
# src/train.py examples/ascend/qwen3_5_full_sft_fsdp2.yaml
|
||||||
|
#
|
||||||
|
# Note: Change `num_processes` in fsdp2_config_qwen35.yaml to match your NPU count
|
||||||
|
|
||||||
|
### model
|
||||||
|
model_name_or_path: Qwen/Qwen3.5-4B
|
||||||
|
trust_remote_code: true
|
||||||
|
use_v1_kernels: true
|
||||||
|
flash_attn: fa2
|
||||||
|
|
||||||
|
### method
|
||||||
|
stage: sft
|
||||||
|
do_train: true
|
||||||
|
finetuning_type: full
|
||||||
|
|
||||||
|
### dataset
|
||||||
|
dataset: alpaca_en_demo
|
||||||
|
template: qwen3_5_nothink
|
||||||
|
cutoff_len: 2048
|
||||||
|
max_samples: 1000
|
||||||
|
overwrite_cache: true
|
||||||
|
preprocessing_num_workers: 16
|
||||||
|
dataloader_num_workers: 4
|
||||||
|
|
||||||
|
### output
|
||||||
|
output_dir: saves/Qwen3.5-4B/full/sft
|
||||||
|
logging_steps: 1
|
||||||
|
save_steps: 500
|
||||||
|
max_steps: 500
|
||||||
|
plot_loss: true
|
||||||
|
overwrite_output_dir: true
|
||||||
|
save_only_model: false
|
||||||
|
report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]
|
||||||
|
|
||||||
|
### train
|
||||||
|
per_device_train_batch_size: 8
|
||||||
|
gradient_accumulation_steps: 1
|
||||||
|
learning_rate: 1.0e-5
|
||||||
|
lr_scheduler_type: cosine
|
||||||
|
warmup_ratio: 0.1
|
||||||
|
bf16: true
|
||||||
|
ddp_timeout: 1800
|
||||||
|
resume_from_checkpoint: null
|
||||||
@@ -23,25 +23,104 @@ Init Phase:
|
|||||||
import re
|
import re
|
||||||
import types
|
import types
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
from ......accelerator.helper import DeviceType
|
from ......accelerator.helper import DeviceType
|
||||||
from ......utils.types import HFModel
|
from ......utils.types import HFModel
|
||||||
from ...base import BaseKernel
|
from ...base import BaseKernel
|
||||||
from ...registry import register_kernel
|
from ...registry import register_kernel
|
||||||
|
|
||||||
|
|
||||||
def npu_rms_norm_forward(self, hidden_states):
|
try:
|
||||||
"""NPU forward implementation for RMSNorm.
|
import torch_npu
|
||||||
|
except ImportError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def _should_use_residual_rmsnorm(module):
|
||||||
|
"""Detect whether the module uses residual RMSNorm parameterization.
|
||||||
|
|
||||||
|
Residual RMSNorm uses ``scale = 1.0 + weight`` where weight is initialized to 0,
|
||||||
|
while standard RMSNorm uses ``scale = weight`` where weight is initialized to 1.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
self: RMSNorm module instance with `weight` and `variance_epsilon`.
|
module (nn.Module): The RMSNorm module to check.
|
||||||
hidden_states (Tensor): Input hidden states tensor, same shape as the baseline.
|
|
||||||
|
Returns:
|
||||||
|
bool: ``True`` if the module uses residual parameterization, ``False`` otherwise.
|
||||||
|
|
||||||
|
.. note::
|
||||||
|
This detection ensures compatibility with future model versions (e.g., Qwen3.6, Qwen4.0)
|
||||||
|
without hardcoding version numbers. Two methods are used: weight value inspection
|
||||||
|
(most reliable) and class name pattern matching (backward compatibility).
|
||||||
|
"""
|
||||||
|
if hasattr(module, "weight") and module.weight is not None:
|
||||||
|
weight_mean = module.weight.data.mean().item()
|
||||||
|
if abs(weight_mean) < 0.3:
|
||||||
|
return True
|
||||||
|
|
||||||
|
class_name = module.__class__.__name__
|
||||||
|
residual_patterns = ["Qwen3_5", "Qwen3_6", "Qwen4"]
|
||||||
|
for pattern in residual_patterns:
|
||||||
|
if pattern in class_name:
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def npu_rms_norm_forward(self, hidden_states):
    """NPU forward implementation for standard and residual RMSNorm.

    The epsilon is read from ``variance_epsilon``, then ``eps``, and falls back to
    ``1e-6`` only when neither attribute is set (an explicit ``0.0`` is honored).
    The gain passed to the fused kernel is ``1.0 + weight`` for residual-style
    modules and ``weight`` otherwise; modules without a weight are normalized
    with a gain of ones.

    Args:
        self (nn.Module): The RMSNorm module instance, optionally carrying ``weight``
            and ``variance_epsilon`` / ``eps``.
        hidden_states (Tensor): Input hidden states tensor.

    Returns:
        Tensor: Normalized tensor in the dtype of ``hidden_states``.
    """
    # Use explicit ``is None`` checks: ``a or b`` would discard a legitimate 0.0.
    eps = getattr(self, "variance_epsilon", None)
    if eps is None:
        eps = getattr(self, "eps", None)
    if eps is None:
        eps = 1e-6

    weight = getattr(self, "weight", None)
    if weight is None:
        # The fused kernel needs a gain tensor; ones leave the normalized values unscaled.
        unit_gain = hidden_states.new_ones(hidden_states.shape[-1])
        return torch_npu.npu_rms_norm(hidden_states, unit_gain, epsilon=eps)[0]

    # The detection may read the weight back to the host, so decide once per module
    # and reuse the answer on later calls instead of re-running it every forward.
    # NOTE(review): assumes weights are loaded before the first forward call.
    use_residual = getattr(self, "_npu_use_residual_rmsnorm", None)
    if use_residual is None:
        use_residual = _should_use_residual_rmsnorm(self)
        self._npu_use_residual_rmsnorm = use_residual

    if use_residual:
        # Add in FP32 so the ``1.0 +`` offset is not rounded in low precision.
        gain = (1.0 + weight.float()).to(hidden_states.dtype)
    else:
        gain = weight.to(hidden_states.dtype)

    return torch_npu.npu_rms_norm(hidden_states, gain, epsilon=eps)[0]
|
||||||
|
|
||||||
|
|
||||||
|
def npu_gated_rms_norm_forward(self, hidden_states, gate=None):
    """NPU forward implementation for Gated RMSNorm computed in FP32.

    The input is upcast to FP32, normalized with the fused NPU kernel using
    ``weight`` as the gain (standard parameterization, no ``1.0 + weight``
    adjustment), optionally multiplied by ``silu(gate)``, and cast back to the
    original input dtype.

    The epsilon is read from ``variance_epsilon``, then ``eps``, and falls back to
    ``1e-6`` only when neither attribute is set (an explicit ``0.0`` is honored).

    Args:
        self (nn.Module): The Gated RMSNorm module instance with ``weight``.
        hidden_states (Tensor): Input hidden states tensor.
        gate (Tensor, optional): Gate tensor for SiLU activation. Defaults to ``None``.

    Returns:
        Tensor: Output tensor cast back to the original input dtype.
    """
    input_dtype = hidden_states.dtype

    # Use explicit ``is None`` checks: ``a or b`` would discard a legitimate 0.0.
    eps = getattr(self, "variance_epsilon", None)
    if eps is None:
        eps = getattr(self, "eps", None)
    if eps is None:
        eps = 1e-6

    normed = torch_npu.npu_rms_norm(hidden_states.to(torch.float32), self.weight.float(), epsilon=eps)[0]

    if gate is not None:
        # Gate in FP32 as well; the result is downcast only once, at the end.
        normed = normed * F.silu(gate.to(torch.float32))

    return normed.to(input_dtype)
|
||||||
|
|
||||||
|
|
||||||
@register_kernel
|
@register_kernel
|
||||||
@@ -55,12 +134,9 @@ class NpuRMSNormKernel(BaseKernel):
|
|||||||
def apply(cls, **kwargs) -> "HFModel":
|
def apply(cls, **kwargs) -> "HFModel":
|
||||||
"""Iterate the model and apply NPU-optimized forward to matched RMSNorm modules.
|
"""Iterate the model and apply NPU-optimized forward to matched RMSNorm modules.
|
||||||
|
|
||||||
Key points:
|
Matches modules whose class name contains "RMSNorm" (case-insensitive) and binds
|
||||||
- Match modules whose class name contains "RMSNorm" (case-insensitive).
|
the appropriate NPU-optimized forward function as an instance method via
|
||||||
- Bind `_npu_rms_forward` as an instance method via `types.MethodType` to
|
``types.MethodType`` to replace the original ``forward``.
|
||||||
replace the original `forward`.
|
|
||||||
- Do not modify weights, hyperparameters, or module structure to ensure
|
|
||||||
numerical behavior and interface consistency.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
**kwargs: Keyword arguments containing the model.
|
**kwargs: Keyword arguments containing the model.
|
||||||
@@ -69,7 +145,7 @@ class NpuRMSNormKernel(BaseKernel):
|
|||||||
HFModel: The model with NPU fused RMSNorm.
|
HFModel: The model with NPU fused RMSNorm.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
RuntimeError: If torch_npu is not available.
|
RuntimeError: If ``torch_npu`` is not available.
|
||||||
ValueError: If the model is not provided.
|
ValueError: If the model is not provided.
|
||||||
"""
|
"""
|
||||||
model = kwargs.get("model")
|
model = kwargs.get("model")
|
||||||
@@ -81,11 +157,11 @@ class NpuRMSNormKernel(BaseKernel):
|
|||||||
|
|
||||||
rms_norm_pattern = re.compile("RMSNorm", re.IGNORECASE)
|
rms_norm_pattern = re.compile("RMSNorm", re.IGNORECASE)
|
||||||
|
|
||||||
for name, module in model.named_modules():
|
for _, module in model.named_modules():
|
||||||
# Match any module whose class name contains "RMSNorm"
|
|
||||||
if re.search(rms_norm_pattern, module.__class__.__name__):
|
if re.search(rms_norm_pattern, module.__class__.__name__):
|
||||||
# Bind function as an instance method to preserve `self` semantics
|
if "Gated" in module.__class__.__name__:
|
||||||
# and replace the original forward
|
module.forward = types.MethodType(npu_gated_rms_norm_forward, module)
|
||||||
|
else:
|
||||||
module.forward = types.MethodType(npu_rms_norm_forward, module)
|
module.forward = types.MethodType(npu_rms_norm_forward, module)
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
|||||||
@@ -39,40 +39,80 @@ except ImportError:
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
def _apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
|
def _apply_npu_rotary_emb(q, k, cos, sin):
    """Rotate query and key tensors with the fused NPU rotary kernel.

    Full versus Partial RoPE is chosen from the tensor shapes alone: when the last
    dimension of ``cos`` is smaller than the last dimension of ``q``, only the
    leading ``cos.shape[-1]`` channels of ``q`` and ``k`` are rotated and the
    remaining channels are passed through unchanged.

    Args:
        q (Tensor): Query tensor.
        k (Tensor): Key tensor.
        cos (Tensor): Cosine part of rotary embedding (already unsqueezed).
        sin (Tensor): Sine part of rotary embedding (already unsqueezed).

    Returns:
        tuple[Tensor, Tensor]: The embedded query and key tensors ``(q_embed, k_embed)``.
    """
    rope_width = cos.shape[-1]

    # Full RoPE: the rotary table covers every channel of the head.
    if rope_width >= q.shape[-1]:
        return (
            torch_npu.npu_rotary_mul(q, cos, sin).to(q.dtype),
            torch_npu.npu_rotary_mul(k, cos, sin).to(k.dtype),
        )

    # Partial RoPE: rotate the leading slice, then re-attach the untouched tail.
    outputs = []
    for tensor in (q, k):
        rotated_head = torch_npu.npu_rotary_mul(tensor[..., :rope_width], cos, sin).to(tensor.dtype)
        outputs.append(torch.cat([rotated_head, tensor[..., rope_width:]], dim=-1))

    return outputs[0], outputs[1]
|
||||||
|
|
||||||
|
|
||||||
def _apply_multimodal_rotary_pos_emb_qwen25_vl(q, k, cos, sin, mrope_section, unsqueeze_dim=1):
|
def _apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Apply Rotary Position Embedding to query and key tensors on NPU.

    ``cos`` and ``sin`` are unsqueezed along ``unsqueeze_dim`` and the rotation is
    delegated to ``_apply_npu_rotary_emb``, which handles both Full and Partial RoPE.

    Args:
        q (Tensor): Query tensor.
        k (Tensor): Key tensor.
        cos (Tensor): Cosine part of embedding.
        sin (Tensor): Sine part of embedding.
        position_ids (Tensor, optional): Position IDs; not read by this function.
            Defaults to ``None``.
        unsqueeze_dim (int): Dimension to unsqueeze cos and sin. Defaults to 1.

    Returns:
        tuple[Tensor, Tensor]: The embedded query and key tensors ``(q_embed, k_embed)``.
    """
    expanded_cos = cos.unsqueeze(unsqueeze_dim)
    expanded_sin = sin.unsqueeze(unsqueeze_dim)
    return _apply_npu_rotary_emb(q, k, expanded_cos, expanded_sin)
|
||||||
|
|
||||||
|
|
||||||
|
def _apply_multimodal_rotary_pos_emb_qwen25_vl(q, k, cos, sin, mrope_section, unsqueeze_dim=1):
|
||||||
|
"""Apply Rotary Position Embedding with multimodal sections (Qwen2-VL) on NPU.
|
||||||
|
|
||||||
|
This function supports Partial RoPE for multimodal inputs with automatic dimension
|
||||||
|
detection, ensuring compatibility with future model versions.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
q (Tensor): Query tensor.
|
||||||
|
k (Tensor): Key tensor.
|
||||||
|
cos (Tensor): Cosine part of embedding.
|
||||||
|
sin (Tensor): Sine part of embedding.
|
||||||
|
mrope_section (list[int]): Multimodal RoPE section sizes.
|
||||||
|
unsqueeze_dim (int): Dimension to unsqueeze cos and sin. Defaults to 1.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tuple[Tensor, Tensor]: The embedded query and key tensors ``(q_embed, k_embed)``.
|
||||||
"""
|
"""
|
||||||
mrope_section = mrope_section * 2
|
mrope_section = mrope_section * 2
|
||||||
cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze(
|
cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze(
|
||||||
@@ -82,9 +122,7 @@ def _apply_multimodal_rotary_pos_emb_qwen25_vl(q, k, cos, sin, mrope_section, un
|
|||||||
unsqueeze_dim
|
unsqueeze_dim
|
||||||
)
|
)
|
||||||
|
|
||||||
q_embed = torch_npu.npu_rotary_mul(q, cos, sin)
|
return _apply_npu_rotary_emb(q, k, cos, sin)
|
||||||
k_embed = torch_npu.npu_rotary_mul(k, cos, sin)
|
|
||||||
return q_embed, k_embed
|
|
||||||
|
|
||||||
|
|
||||||
@register_kernel
|
@register_kernel
|
||||||
@@ -96,12 +134,12 @@ class NpuRoPEKernel(BaseKernel):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def apply(cls, **kwargs) -> "HFModel":
|
def apply(cls, **kwargs) -> "HFModel":
|
||||||
"""Apply RoPE acceleration by monkey-patching `apply_rotary_pos_emb`.
|
"""Apply RoPE acceleration by monkey-patching ``apply_rotary_pos_emb``.
|
||||||
|
|
||||||
This function iterates through the model's modules to find attention layers,
|
Iterates through the model's modules to find attention layers, identifies
|
||||||
identifies the module where they are defined, and replaces the original
|
the module where they are defined, and replaces the original
|
||||||
`apply_rotary_pos_emb` function in that module's namespace with the
|
``apply_rotary_pos_emb`` function in that module's namespace with the
|
||||||
NPU-accelerated version from this file.
|
NPU-accelerated version.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
**kwargs: Keyword arguments containing the model.
|
**kwargs: Keyword arguments containing the model.
|
||||||
@@ -110,7 +148,7 @@ class NpuRoPEKernel(BaseKernel):
|
|||||||
HFModel: The model with patched RoPE functions.
|
HFModel: The model with patched RoPE functions.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
RuntimeError: If dependencies are not met.
|
RuntimeError: If ``torch_npu`` is not available.
|
||||||
ValueError: If the model is not provided.
|
ValueError: If the model is not provided.
|
||||||
"""
|
"""
|
||||||
if not cls.check_deps():
|
if not cls.check_deps():
|
||||||
|
|||||||
Reference in New Issue
Block a user