mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2026-05-05 07:38:55 +08:00
[npu] add Qwen3.5 support with Partial RoPE and Hybrid Attention (#10421)
Co-authored-by: Curnane <mingliangfu@users.noreply.github.com>
This commit is contained in:
20
examples/accelerate/fsdp2_config_qwen35.yaml
Normal file
20
examples/accelerate/fsdp2_config_qwen35.yaml
Normal file
@@ -0,0 +1,20 @@
|
||||
compute_environment: LOCAL_MACHINE
|
||||
debug: false
|
||||
distributed_type: FSDP
|
||||
downcast_bf16: 'no'
|
||||
fsdp_config:
|
||||
fsdp_version: 2
|
||||
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
||||
fsdp_transformer_layer_cls_to_wrap: Qwen3_5DecoderLayer,Qwen3_5VisionBlock
|
||||
fsdp_cpu_ram_efficient_loading: true
|
||||
fsdp_offload_params: false
|
||||
fsdp_reshard_after_forward: true
|
||||
fsdp_state_dict_type: FULL_STATE_DICT
|
||||
machine_rank: 0
|
||||
main_training_function: main
|
||||
mixed_precision: bf16
|
||||
num_machines: 1
|
||||
num_processes: 8 # Change to match your NPU count (e.g., 8 for A2, 16 for A3)
|
||||
rdzv_backend: static
|
||||
same_network: true
|
||||
use_cpu: false
|
||||
47
examples/ascend/qwen3_5_full_sft_fsdp2.yaml
Normal file
47
examples/ascend/qwen3_5_full_sft_fsdp2.yaml
Normal file
@@ -0,0 +1,47 @@
|
||||
# Start FSDP2 full fine-tuning on Ascend NPU
|
||||
# Usage:
|
||||
# accelerate launch \
|
||||
# --config_file examples/accelerate/fsdp2_config_qwen35.yaml \
|
||||
# src/train.py examples/ascend/qwen3_5_full_sft_fsdp2.yaml
|
||||
#
|
||||
# Note: Change `num_processes` in fsdp2_config_qwen35.yaml to match your NPU count
|
||||
|
||||
### model
|
||||
model_name_or_path: Qwen/Qwen3.5-4B
|
||||
trust_remote_code: true
|
||||
use_v1_kernels: true
|
||||
flash_attn: fa2
|
||||
|
||||
### method
|
||||
stage: sft
|
||||
do_train: true
|
||||
finetuning_type: full
|
||||
|
||||
### dataset
|
||||
dataset: alpaca_en_demo
|
||||
template: qwen3_5_nothink
|
||||
cutoff_len: 2048
|
||||
max_samples: 1000
|
||||
overwrite_cache: true
|
||||
preprocessing_num_workers: 16
|
||||
dataloader_num_workers: 4
|
||||
|
||||
### output
|
||||
output_dir: saves/Qwen3.5-4B/full/sft
|
||||
logging_steps: 1
|
||||
save_steps: 500
|
||||
max_steps: 500
|
||||
plot_loss: true
|
||||
overwrite_output_dir: true
|
||||
save_only_model: false
|
||||
report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]
|
||||
|
||||
### train
|
||||
per_device_train_batch_size: 8
|
||||
gradient_accumulation_steps: 1
|
||||
learning_rate: 1.0e-5
|
||||
lr_scheduler_type: cosine
|
||||
warmup_ratio: 0.1
|
||||
bf16: true
|
||||
ddp_timeout: 1800
|
||||
resume_from_checkpoint: null
|
||||
@@ -23,25 +23,104 @@ Init Phase:
|
||||
import re
|
||||
import types
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ......accelerator.helper import DeviceType
|
||||
from ......utils.types import HFModel
|
||||
from ...base import BaseKernel
|
||||
from ...registry import register_kernel
|
||||
|
||||
|
||||
def npu_rms_norm_forward(self, hidden_states):
|
||||
"""NPU forward implementation for RMSNorm.
|
||||
try:
|
||||
import torch_npu
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
def _should_use_residual_rmsnorm(module):
    """Detect whether the module uses residual RMSNorm parameterization.

    Residual RMSNorm uses ``scale = 1.0 + weight`` where weight is initialized to 0,
    while standard RMSNorm uses ``scale = weight`` where weight is initialized to 1.

    Args:
        module (nn.Module): The RMSNorm module to check.

    Returns:
        bool: ``True`` if the module uses residual parameterization, ``False`` otherwise.

    .. note::
        Two checks are applied in order:

        1. Weight value inspection: a weight whose mean has magnitude below 0.3 is
           treated as zero-centred, i.e. residual. This is a heuristic on the current
           weight values, not a guarantee.
        2. Class name matching against known residual families (``Qwen3_5``,
           ``Qwen3_6``, ``Qwen4``).

        Weights that cannot be inspected (missing, ``None``, on the meta device, or
        empty) skip the first check instead of raising, so the class-name check still
        gets a chance to decide.
    """
    weight = getattr(module, "weight", None)

    # `.item()` cannot be called on meta tensors, and the mean of an empty tensor is
    # NaN; in both cases the value check carries no information, so skip it.
    inspectable = weight is not None and not getattr(weight, "is_meta", False) and weight.numel() > 0
    if inspectable:
        weight_mean = weight.detach().mean().item()
        if abs(weight_mean) < 0.3:
            return True

    class_name = module.__class__.__name__
    return any(pattern in class_name for pattern in ("Qwen3_5", "Qwen3_6", "Qwen4"))
|
||||
|
||||
|
||||
def npu_rms_norm_forward(self, hidden_states):
    """NPU forward implementation for RMSNorm (standard or residual parameterization).

    The scale passed to the fused kernel is ``1.0 + weight`` when
    ``_should_use_residual_rmsnorm`` reports a residual module, and ``weight``
    otherwise. That decision is made once, on the first call, and cached on the
    module: the detector inspects weight values via ``.item()``, so re-running it per
    forward would add a host sync and could change the answer as weights drift during
    training.

    Args:
        self (nn.Module): The RMSNorm module instance. Epsilon is read from
            ``variance_epsilon`` or, failing that, ``eps``; 1e-6 is used when neither
            is set.
        hidden_states (Tensor): Input hidden states tensor.

    Returns:
        Tensor: The first element of the ``torch_npu.npu_rms_norm`` result.

    Raises:
        ImportError: If ``torch_npu`` is not installed.
    """
    import torch_npu

    # Resolve epsilon with explicit None checks so that a legitimate 0.0 is kept.
    eps = getattr(self, "variance_epsilon", None)
    if eps is None:
        eps = getattr(self, "eps", None)
    if eps is None:
        eps = 1e-6

    weight = getattr(self, "weight", None)
    if weight is None:
        # No learnable scale: use an identity scale.
        # NOTE(review): assumes normalization runs over the last dimension — confirm.
        scale = hidden_states.new_ones(hidden_states.shape[-1])
        return torch_npu.npu_rms_norm(hidden_states, scale, epsilon=eps)[0]

    use_residual = getattr(self, "_npu_use_residual_rmsnorm", None)
    if use_residual is None:
        use_residual = _should_use_residual_rmsnorm(self)
        self._npu_use_residual_rmsnorm = use_residual

    # Build the scale in FP32, then match the activation dtype expected by the kernel.
    scale = weight.float()
    if use_residual:
        scale = 1.0 + scale

    return torch_npu.npu_rms_norm(hidden_states, scale.to(hidden_states.dtype), epsilon=eps)[0]
|
||||
|
||||
|
||||
def npu_gated_rms_norm_forward(self, hidden_states, gate=None):
    """NPU forward implementation for Gated RMSNorm with FP32 computation.

    The input is upcast to FP32, normalized with ``torch_npu.npu_rms_norm`` using
    ``self.weight`` as the scale (no ``1.0 + weight`` adjustment is applied here),
    optionally multiplied by ``SiLU(gate)`` in FP32, and cast back to the input dtype.

    Args:
        self (nn.Module): The Gated RMSNorm module instance. Epsilon is read from
            ``variance_epsilon`` or, failing that, ``eps``; 1e-6 is used when neither
            is set.
        hidden_states (Tensor): Input hidden states tensor.
        gate (Tensor, optional): Gate tensor for SiLU activation. When ``None`` the
            gating multiplication is skipped. Defaults to ``None``.

    Returns:
        Tensor: Output tensor cast back to the original input dtype.

    Raises:
        ImportError: If ``torch_npu`` is not installed.
    """
    import torch_npu

    # Resolve epsilon with explicit None checks so that a legitimate 0.0 is kept.
    eps = getattr(self, "variance_epsilon", None)
    if eps is None:
        eps = getattr(self, "eps", None)
    if eps is None:
        eps = 1e-6

    input_dtype = hidden_states.dtype
    normed = torch_npu.npu_rms_norm(hidden_states.to(torch.float32), self.weight.float(), epsilon=eps)[0]

    if gate is not None:
        normed = normed * F.silu(gate.to(torch.float32))

    return normed.to(input_dtype)
|
||||
|
||||
|
||||
@register_kernel
|
||||
@@ -55,12 +134,9 @@ class NpuRMSNormKernel(BaseKernel):
|
||||
def apply(cls, **kwargs) -> "HFModel":
|
||||
"""Iterate the model and apply NPU-optimized forward to matched RMSNorm modules.
|
||||
|
||||
Key points:
|
||||
- Match modules whose class name contains "RMSNorm" (case-insensitive).
|
||||
- Bind `_npu_rms_forward` as an instance method via `types.MethodType` to
|
||||
replace the original `forward`.
|
||||
- Do not modify weights, hyperparameters, or module structure to ensure
|
||||
numerical behavior and interface consistency.
|
||||
Matches modules whose class name contains "RMSNorm" (case-insensitive) and binds
|
||||
the appropriate NPU-optimized forward function as an instance method via
|
||||
``types.MethodType`` to replace the original ``forward``.
|
||||
|
||||
Args:
|
||||
**kwargs: Keyword arguments containing the model.
|
||||
@@ -69,7 +145,7 @@ class NpuRMSNormKernel(BaseKernel):
|
||||
HFModel: The model with NPU fused RMSNorm.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If torch_npu is not available.
|
||||
RuntimeError: If ``torch_npu`` is not available.
|
||||
ValueError: If the model is not provided.
|
||||
"""
|
||||
model = kwargs.get("model")
|
||||
@@ -81,11 +157,11 @@ class NpuRMSNormKernel(BaseKernel):
|
||||
|
||||
rms_norm_pattern = re.compile("RMSNorm", re.IGNORECASE)
|
||||
|
||||
for name, module in model.named_modules():
|
||||
# Match any module whose class name contains "RMSNorm"
|
||||
for _, module in model.named_modules():
|
||||
if re.search(rms_norm_pattern, module.__class__.__name__):
|
||||
# Bind function as an instance method to preserve `self` semantics
|
||||
# and replace the original forward
|
||||
module.forward = types.MethodType(npu_rms_norm_forward, module)
|
||||
if "Gated" in module.__class__.__name__:
|
||||
module.forward = types.MethodType(npu_gated_rms_norm_forward, module)
|
||||
else:
|
||||
module.forward = types.MethodType(npu_rms_norm_forward, module)
|
||||
|
||||
return model
|
||||
|
||||
@@ -39,40 +39,80 @@ except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
def _apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
|
||||
"""Applies Rotary Position Embedding to the query and key tensors using NPU optimization.
|
||||
def _apply_npu_rotary_emb(q, k, cos, sin):
    """Apply NPU-accelerated rotary embedding with automatic Partial RoPE detection.

    Partial RoPE is selected when the last dimension of ``cos`` is smaller than the
    last dimension of ``q``: only the leading ``cos.shape[-1]`` features of ``q`` and
    ``k`` are rotated and the remaining features are passed through unchanged.
    Otherwise the whole tensors are rotated (Full RoPE).

    Args:
        q (Tensor): Query tensor.
        k (Tensor): Key tensor.
        cos (Tensor): Cosine part of rotary embedding (already unsqueezed).
        sin (Tensor): Sine part of rotary embedding (already unsqueezed).

    Returns:
        tuple[Tensor, Tensor]: The embedded query and key tensors ``(q_embed, k_embed)``,
        each cast back to the dtype of its input.
    """
    rotary_dim = cos.shape[-1]
    # The decision is taken from the query width and applied to both q and k.
    is_partial = rotary_dim < q.shape[-1]

    def _rotate(x):
        """Rotate ``x`` (or its leading ``rotary_dim`` features) and keep its dtype."""
        if not is_partial:
            return torch_npu.npu_rotary_mul(x, cos, sin).to(x.dtype)

        x_rot, x_pass = x[..., :rotary_dim], x[..., rotary_dim:]
        x_rot = torch_npu.npu_rotary_mul(x_rot, cos, sin).to(x.dtype)
        return torch.cat([x_rot, x_pass], dim=-1)

    return _rotate(q), _rotate(k)
|
||||
|
||||
|
||||
def _apply_multimodal_rotary_pos_emb_qwen25_vl(q, k, cos, sin, mrope_section, unsqueeze_dim=1):
|
||||
"""Applies Rotary Position Embedding with multimodal sections (Qwen2-VL) on NPU.
|
||||
def _apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Apply Rotary Position Embedding to query and key tensors using NPU optimization.

    ``cos`` and ``sin`` are unsqueezed along ``unsqueeze_dim`` and the rotation is
    delegated to ``_apply_npu_rotary_emb``, which handles both Full RoPE and Partial
    RoPE based on the last dimensions of ``cos`` and ``q``.

    Args:
        q (Tensor): Query tensor.
        k (Tensor): Key tensor.
        cos (Tensor): Cosine part of embedding.
        sin (Tensor): Sine part of embedding.
        position_ids (Tensor, optional): Not used by this implementation; presumably
            kept so the signature matches the ``apply_rotary_pos_emb`` it replaces.
            Defaults to ``None``.
        unsqueeze_dim (int): Dimension to unsqueeze cos and sin. Defaults to 1.

    Returns:
        tuple[Tensor, Tensor]: The embedded query and key tensors ``(q_embed, k_embed)``.
    """
    return _apply_npu_rotary_emb(q, k, cos.unsqueeze(unsqueeze_dim), sin.unsqueeze(unsqueeze_dim))
|
||||
|
||||
|
||||
def _apply_multimodal_rotary_pos_emb_qwen25_vl(q, k, cos, sin, mrope_section, unsqueeze_dim=1):
|
||||
"""Apply Rotary Position Embedding with multimodal sections (Qwen2-VL) on NPU.
|
||||
|
||||
This function supports Partial RoPE for multimodal inputs with automatic dimension
|
||||
detection, ensuring compatibility with future model versions.
|
||||
|
||||
Args:
|
||||
q (Tensor): Query tensor.
|
||||
k (Tensor): Key tensor.
|
||||
cos (Tensor): Cosine part of embedding.
|
||||
sin (Tensor): Sine part of embedding.
|
||||
mrope_section (list[int]): Multimodal RoPE section sizes.
|
||||
unsqueeze_dim (int): Dimension to unsqueeze cos and sin. Defaults to 1.
|
||||
|
||||
Returns:
|
||||
tuple[Tensor, Tensor]: The embedded query and key tensors ``(q_embed, k_embed)``.
|
||||
"""
|
||||
mrope_section = mrope_section * 2
|
||||
cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze(
|
||||
@@ -82,9 +122,7 @@ def _apply_multimodal_rotary_pos_emb_qwen25_vl(q, k, cos, sin, mrope_section, un
|
||||
unsqueeze_dim
|
||||
)
|
||||
|
||||
q_embed = torch_npu.npu_rotary_mul(q, cos, sin)
|
||||
k_embed = torch_npu.npu_rotary_mul(k, cos, sin)
|
||||
return q_embed, k_embed
|
||||
return _apply_npu_rotary_emb(q, k, cos, sin)
|
||||
|
||||
|
||||
@register_kernel
|
||||
@@ -96,12 +134,12 @@ class NpuRoPEKernel(BaseKernel):
|
||||
|
||||
@classmethod
|
||||
def apply(cls, **kwargs) -> "HFModel":
|
||||
"""Apply RoPE acceleration by monkey-patching `apply_rotary_pos_emb`.
|
||||
"""Apply RoPE acceleration by monkey-patching ``apply_rotary_pos_emb``.
|
||||
|
||||
This function iterates through the model's modules to find attention layers,
|
||||
identifies the module where they are defined, and replaces the original
|
||||
`apply_rotary_pos_emb` function in that module's namespace with the
|
||||
NPU-accelerated version from this file.
|
||||
Iterates through the model's modules to find attention layers, identifies
|
||||
the module where they are defined, and replaces the original
|
||||
``apply_rotary_pos_emb`` function in that module's namespace with the
|
||||
NPU-accelerated version.
|
||||
|
||||
Args:
|
||||
**kwargs: Keyword arguments containing the model.
|
||||
@@ -110,7 +148,7 @@ class NpuRoPEKernel(BaseKernel):
|
||||
HFModel: The model with patched RoPE functions.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If dependencies are not met.
|
||||
RuntimeError: If ``torch_npu`` is not available.
|
||||
ValueError: If the model is not provided.
|
||||
"""
|
||||
if not cls.check_deps():
|
||||
|
||||
Reference in New Issue
Block a user