diff --git a/examples/accelerate/fsdp2_config_qwen35.yaml b/examples/accelerate/fsdp2_config_qwen35.yaml
new file mode 100644
index 000000000..edc8b8ffb
--- /dev/null
+++ b/examples/accelerate/fsdp2_config_qwen35.yaml
@@ -0,0 +1,20 @@
+compute_environment: LOCAL_MACHINE
+debug: false
+distributed_type: FSDP
+downcast_bf16: 'no'
+fsdp_config:
+  fsdp_version: 2
+  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
+  fsdp_transformer_layer_cls_to_wrap: Qwen3_5DecoderLayer,Qwen3_5VisionBlock
+  fsdp_cpu_ram_efficient_loading: true
+  fsdp_offload_params: false
+  fsdp_reshard_after_forward: true
+  fsdp_state_dict_type: FULL_STATE_DICT
+machine_rank: 0
+main_training_function: main
+mixed_precision: bf16
+num_machines: 1
+num_processes: 8  # Change to match your NPU count (e.g., 8 for A2, 16 for A3)
+rdzv_backend: static
+same_network: true
+use_cpu: false
diff --git a/examples/ascend/qwen3_5_full_sft_fsdp2.yaml b/examples/ascend/qwen3_5_full_sft_fsdp2.yaml
new file mode 100644
index 000000000..915b70f81
--- /dev/null
+++ b/examples/ascend/qwen3_5_full_sft_fsdp2.yaml
@@ -0,0 +1,47 @@
+# Start FSDP2 full fine-tuning on Ascend NPU
+# Usage:
+#   accelerate launch \
+#     --config_file examples/accelerate/fsdp2_config_qwen35.yaml \
+#     src/train.py examples/ascend/qwen3_5_full_sft_fsdp2.yaml
+#
+# Note: change `num_processes` in fsdp2_config_qwen35.yaml to match your NPU count.
+
+### model
+model_name_or_path: Qwen/Qwen3.5-4B
+trust_remote_code: true
+use_v1_kernels: true
+flash_attn: fa2
+
+### method
+stage: sft
+do_train: true
+finetuning_type: full
+
+### dataset
+dataset: alpaca_en_demo
+template: qwen3_5_nothink
+cutoff_len: 2048
+max_samples: 1000
+overwrite_cache: true
+preprocessing_num_workers: 16
+dataloader_num_workers: 4
+
+### output
+output_dir: saves/Qwen3.5-4B/full/sft
+logging_steps: 1
+save_steps: 500
+max_steps: 500
+plot_loss: true
+overwrite_output_dir: true
+save_only_model: false
+report_to: none  # choices: [none, wandb, tensorboard, swanlab, mlflow]
+
+### train
+per_device_train_batch_size: 8
+gradient_accumulation_steps: 1
+learning_rate: 1.0e-5
+lr_scheduler_type: cosine
+warmup_ratio: 0.1
+bf16: true
+ddp_timeout: 1800
+resume_from_checkpoint: null
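With this pair of configs, the effective global batch size is `per_device_train_batch_size (8) × num_processes (8) × gradient_accumulation_steps (1) = 64` sequences per optimizer step; rescale `gradient_accumulation_steps` accordingly if you change `num_processes`. A typo in `fsdp_transformer_layer_cls_to_wrap` only surfaces deep inside the accelerate launch, so a quick pre-flight check is worthwhile. A minimal sketch, assuming the checkpoint resolves through transformers' Auto classes (a multimodal checkpoint may need a different Auto class than the one shown):

```python
# Hypothetical pre-flight check: instantiate the model on the meta device and
# verify that the wrap targets from fsdp2_config_qwen35.yaml exist in it.
from accelerate import init_empty_weights
from transformers import AutoConfig, AutoModelForCausalLM

config = AutoConfig.from_pretrained("Qwen/Qwen3.5-4B", trust_remote_code=True)
with init_empty_weights():  # no weights are materialized, so this is cheap
    model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)

present = {m.__class__.__name__ for m in model.modules()}
for target in ("Qwen3_5DecoderLayer", "Qwen3_5VisionBlock"):
    print(target, "ok" if target in present else "MISSING -- fix the accelerate config")
```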
diff --git a/src/llamafactory/v1/plugins/model_plugins/kernels/ops/rms_norm/npu_rms_norm.py b/src/llamafactory/v1/plugins/model_plugins/kernels/ops/rms_norm/npu_rms_norm.py
index 35057f451..3b1c39f88 100644
--- a/src/llamafactory/v1/plugins/model_plugins/kernels/ops/rms_norm/npu_rms_norm.py
+++ b/src/llamafactory/v1/plugins/model_plugins/kernels/ops/rms_norm/npu_rms_norm.py
@@ -23,25 +23,104 @@ Init Phase:
 import re
 import types
 
+import torch
+import torch.nn.functional as F
+
 from ......accelerator.helper import DeviceType
 from ......utils.types import HFModel
 from ...base import BaseKernel
 from ...registry import register_kernel
 
 
-def npu_rms_norm_forward(self, hidden_states):
-    """NPU forward implementation for RMSNorm.
+try:
+    import torch_npu
+except ImportError:
+    pass
+
+
+def _should_use_residual_rmsnorm(module):
+    """Detect whether the module uses residual RMSNorm parameterization.
+
+    Residual RMSNorm uses ``scale = 1.0 + weight`` where weight is initialized to 0,
+    while standard RMSNorm uses ``scale = weight`` where weight is initialized to 1.
 
     Args:
-        self: RMSNorm module instance with `weight` and `variance_epsilon`.
-        hidden_states (Tensor): Input hidden states tensor, same shape as the baseline.
+        module (nn.Module): The RMSNorm module to check.
+
+    Returns:
+        bool: ``True`` if the module uses residual parameterization, ``False`` otherwise.
+
+    .. note::
+        This detection keeps compatibility with future model versions (e.g., Qwen3.6, Qwen4.0)
+        without hardcoding version numbers. Two methods are used: weight value inspection
+        (primary, most reliable) and class-name pattern matching (fallback).
+    """
+    if hasattr(module, "weight") and module.weight is not None:
+        # Residual weights are initialized near 0, standard weights near 1,
+        # so the mean cleanly separates the two parameterizations.
+        weight_mean = module.weight.data.mean().item()
+        if abs(weight_mean) < 0.3:
+            return True
+
+    class_name = module.__class__.__name__
+    residual_patterns = ["Qwen3_5", "Qwen3_6", "Qwen4"]
+    for pattern in residual_patterns:
+        if pattern in class_name:
+            return True
+
+    return False
+
+
+def npu_rms_norm_forward(self, hidden_states):
+    """NPU forward implementation for standard RMSNorm.
+
+    Args:
+        self (nn.Module): The RMSNorm module instance with ``weight`` and ``variance_epsilon``.
+        hidden_states (Tensor): Input hidden states tensor.
 
     Returns:
         Tensor: Normalized tensor consistent with the baseline RMSNorm behavior.
     """
-    import torch_npu
+    _eps = getattr(self, "variance_epsilon", None) or getattr(self, "eps", 1e-6)
 
-    return torch_npu.npu_rms_norm(hidden_states, self.weight, epsilon=self.variance_epsilon)[0]
+    if getattr(self, "weight", None) is not None:
+        if _should_use_residual_rmsnorm(self):
+            effective_weight = 1.0 + self.weight.float()
+        else:
+            effective_weight = self.weight.float()
+        return torch_npu.npu_rms_norm(hidden_states, effective_weight.to(hidden_states.dtype), epsilon=_eps)[0]
+
+    # Weightless RMSNorm: fall back to an all-ones scale so the fused kernel
+    # reduces to plain normalization instead of receiving a ``None`` weight.
+    ones = torch.ones(hidden_states.shape[-1], dtype=hidden_states.dtype, device=hidden_states.device)
+    return torch_npu.npu_rms_norm(hidden_states, ones, epsilon=_eps)[0]
+
+
+def npu_gated_rms_norm_forward(self, hidden_states, gate=None):
+    """NPU forward implementation for Gated RMSNorm with high-precision FP32 computation.
+
+    This function performs RMSNorm and the gated SiLU multiplication in FP32 for numerical
+    stability. Unlike standard RMSNorm, Gated RMSNorm in Qwen3.5 uses standard
+    parameterization (``scale = weight`` where weight is initialized to 1), so the
+    residual weight adjustment (``1.0 + weight``) is not applied here.
+
+    Args:
+        self (nn.Module): The Gated RMSNorm module instance.
+        hidden_states (Tensor): Input hidden states tensor.
+        gate (Tensor, optional): Gate tensor for SiLU activation. Defaults to ``None``.
+
+    Returns:
+        Tensor: Output tensor cast back to the original input dtype.
+    """
+    input_dtype = hidden_states.dtype
+    hidden_states = hidden_states.to(torch.float32)
+    _eps = getattr(self, "variance_epsilon", None) or getattr(self, "eps", 1e-6)
+
+    hidden_states = torch_npu.npu_rms_norm(hidden_states, self.weight.float(), epsilon=_eps)[0]
+
+    if gate is not None:
+        hidden_states = hidden_states * F.silu(gate.to(torch.float32))
+
+    return hidden_states.to(input_dtype)
 
 
 @register_kernel
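For reviewers checking numerics, the two forwards above can be compared against a plain eager reference; the fused `torch_npu.npu_rms_norm` path should agree with it to within mixed-precision tolerance. A minimal sketch, assuming a module exposing `weight` and `variance_epsilon` (the helper name is illustrative):

```python
# Eager fp32 reference covering both parameterizations and the gated variant.
import torch
import torch.nn.functional as F

def reference_rms_norm(x, weight, eps, residual=False, gate=None):
    # residual=True mirrors the `scale = 1.0 + weight` parameterization.
    scale = (1.0 + weight.float()) if residual else weight.float()
    h = x.float()
    h = h * torch.rsqrt(h.pow(2).mean(dim=-1, keepdim=True) + eps) * scale
    if gate is not None:
        # Gated variant: multiply by SiLU(gate), also in fp32.
        h = h * F.silu(gate.float())
    return h.to(x.dtype)
```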
@@ -55,12 +134,9 @@ class NpuRMSNormKernel(BaseKernel):
     def apply(cls, **kwargs) -> "HFModel":
         """Iterate the model and apply NPU-optimized forward to matched RMSNorm modules.
 
-        Key points:
-        - Match modules whose class name contains "RMSNorm" (case-insensitive).
-        - Bind `_npu_rms_forward` as an instance method via `types.MethodType` to
-          replace the original `forward`.
-        - Do not modify weights, hyperparameters, or module structure to ensure
-          numerical behavior and interface consistency.
+        Matches modules whose class name contains "RMSNorm" (case-insensitive) and binds
+        the appropriate NPU-optimized forward function as an instance method via
+        ``types.MethodType`` to replace the original ``forward``.
 
         Args:
             **kwargs: Keyword arguments containing the model.
@@ -69,7 +145,7 @@ class NpuRMSNormKernel(BaseKernel):
             HFModel: The model with NPU fused RMSNorm.
 
         Raises:
-            RuntimeError: If torch_npu is not available.
+            RuntimeError: If ``torch_npu`` is not available.
             ValueError: If the model is not provided.
         """
         model = kwargs.get("model")
@@ -81,11 +157,11 @@ class NpuRMSNormKernel(BaseKernel):
 
         rms_norm_pattern = re.compile("RMSNorm", re.IGNORECASE)
 
-        for name, module in model.named_modules():
-            # Match any module whose class name contains "RMSNorm"
+        for _, module in model.named_modules():
             if re.search(rms_norm_pattern, module.__class__.__name__):
-                # Bind function as an instance method to preserve `self` semantics
-                # and replace the original forward
-                module.forward = types.MethodType(npu_rms_norm_forward, module)
+                if "Gated" in module.__class__.__name__:
+                    module.forward = types.MethodType(npu_gated_rms_norm_forward, module)
+                else:
+                    module.forward = types.MethodType(npu_rms_norm_forward, module)
 
         return model
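The binding strategy here is instance-level, not class-level. A minimal sketch of the same pattern, with illustrative names, showing why it is easy to undo:

```python
# Assigning to module.forward creates an instance attribute that shadows the
# class method, so only the patched instance changes behavior.
import types

def bind_npu_forwards(model, standard_forward, gated_forward):
    for module in model.modules():
        name = module.__class__.__name__
        if "rmsnorm" in name.lower():
            forward = gated_forward if "Gated" in name else standard_forward
            module.forward = types.MethodType(forward, module)
    return model

# `del module.forward` later removes the instance attribute and falls back to
# the original class implementation, which makes A/B numerical checks cheap.
```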
""" - cos = cos.unsqueeze(unsqueeze_dim) - sin = sin.unsqueeze(unsqueeze_dim) - q_embed = torch_npu.npu_rotary_mul(q, cos, sin) - k_embed = torch_npu.npu_rotary_mul(k, cos, sin) + rotary_dim = cos.shape[-1] + query_dim = q.shape[-1] + + if rotary_dim < query_dim: + q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:] + k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:] + + q_embed = torch_npu.npu_rotary_mul(q_rot, cos, sin).to(q.dtype) + k_embed = torch_npu.npu_rotary_mul(k_rot, cos, sin).to(k.dtype) + + q_embed = torch.cat([q_embed, q_pass], dim=-1) + k_embed = torch.cat([k_embed, k_pass], dim=-1) + else: + q_embed = torch_npu.npu_rotary_mul(q, cos, sin).to(q.dtype) + k_embed = torch_npu.npu_rotary_mul(k, cos, sin).to(k.dtype) + return q_embed, k_embed -def _apply_multimodal_rotary_pos_emb_qwen25_vl(q, k, cos, sin, mrope_section, unsqueeze_dim=1): - """Applies Rotary Position Embedding with multimodal sections (Qwen2-VL) on NPU. +def _apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Apply Rotary Position Embedding to query and key tensors using NPU optimization. + + This function automatically supports both Full RoPE and Partial RoPE based on + the dimension ratio between ``cos/sin`` and ``q/k`` tensors. Args: q (Tensor): Query tensor. k (Tensor): Key tensor. cos (Tensor): Cosine part of embedding. sin (Tensor): Sine part of embedding. - mrope_section (Tensor): Multimodal RoPE section. - unsqueeze_dim (int): Dimension to unsqueeze cos and sin. Default: 1. + position_ids (Tensor, optional): Position IDs. Defaults to ``None``. + unsqueeze_dim (int): Dimension to unsqueeze cos and sin. Defaults to 1. Returns: - tuple: (q_embed, k_embed) The embedded query and key tensors. + tuple[Tensor, Tensor]: The embedded query and key tensors ``(q_embed, k_embed)``. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + + return _apply_npu_rotary_emb(q, k, cos, sin) + + +def _apply_multimodal_rotary_pos_emb_qwen25_vl(q, k, cos, sin, mrope_section, unsqueeze_dim=1): + """Apply Rotary Position Embedding with multimodal sections (Qwen2-VL) on NPU. + + This function supports Partial RoPE for multimodal inputs with automatic dimension + detection, ensuring compatibility with future model versions. + + Args: + q (Tensor): Query tensor. + k (Tensor): Key tensor. + cos (Tensor): Cosine part of embedding. + sin (Tensor): Sine part of embedding. + mrope_section (list[int]): Multimodal RoPE section sizes. + unsqueeze_dim (int): Dimension to unsqueeze cos and sin. Defaults to 1. + + Returns: + tuple[Tensor, Tensor]: The embedded query and key tensors ``(q_embed, k_embed)``. """ mrope_section = mrope_section * 2 cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze( @@ -82,9 +122,7 @@ def _apply_multimodal_rotary_pos_emb_qwen25_vl(q, k, cos, sin, mrope_section, un unsqueeze_dim ) - q_embed = torch_npu.npu_rotary_mul(q, cos, sin) - k_embed = torch_npu.npu_rotary_mul(k, cos, sin) - return q_embed, k_embed + return _apply_npu_rotary_emb(q, k, cos, sin) @register_kernel @@ -96,12 +134,12 @@ class NpuRoPEKernel(BaseKernel): @classmethod def apply(cls, **kwargs) -> "HFModel": - """Apply RoPE acceleration by monkey-patching `apply_rotary_pos_emb`. + """Apply RoPE acceleration by monkey-patching ``apply_rotary_pos_emb``. 
@@ -96,12 +134,12 @@ class NpuRoPEKernel(BaseKernel):
 
     @classmethod
     def apply(cls, **kwargs) -> "HFModel":
-        """Apply RoPE acceleration by monkey-patching `apply_rotary_pos_emb`.
+        """Apply RoPE acceleration by monkey-patching ``apply_rotary_pos_emb``.
 
-        This function iterates through the model's modules to find attention layers,
-        identifies the module where they are defined, and replaces the original
-        `apply_rotary_pos_emb` function in that module's namespace with the
-        NPU-accelerated version from this file.
+        Iterates through the model's modules to find attention layers, identifies
+        the module where they are defined, and replaces the original
+        ``apply_rotary_pos_emb`` function in that module's namespace with the
+        NPU-accelerated version.
 
         Args:
             **kwargs: Keyword arguments containing the model.
@@ -110,7 +148,7 @@ class NpuRoPEKernel(BaseKernel):
             HFModel: The model with patched RoPE functions.
 
         Raises:
-            RuntimeError: If dependencies are not met.
+            RuntimeError: If ``torch_npu`` is not available.
             ValueError: If the model is not provided.
         """
         if not cls.check_deps():
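The namespace-level patch described in this docstring differs from the RMSNorm kernel's instance-level binding: it rebinds a module global rather than a method. A minimal sketch of that mechanism, with an illustrative helper name and attention-layer heuristic:

```python
# Rebind apply_rotary_pos_emb in the module where each attention class is
# defined, so existing call sites inside that module pick up the NPU version.
import sys

def patch_rope_namespace(model, npu_apply_rotary_pos_emb):
    for module in model.modules():
        if "Attention" in module.__class__.__name__:
            defining_module = sys.modules[module.__class__.__module__]
            if hasattr(defining_module, "apply_rotary_pos_emb"):
                defining_module.apply_rotary_pos_emb = npu_apply_rotary_pos_emb
    return model
```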