diff --git a/src/llamafactory/model/model_utils/sdpa_npu_redirect.py b/src/llamafactory/model/model_utils/sdpa_npu_redirect.py
deleted file mode 100644
index 7275ad5d..00000000
--- a/src/llamafactory/model/model_utils/sdpa_npu_redirect.py
+++ /dev/null
@@ -1,131 +0,0 @@
-# Copyright 2025 the LlamaFactory team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import logging
-import math
-import os
-from typing import Optional
-
-import torch
-import torch.nn.functional as F
-from transformers.utils import is_torch_npu_available
-
-
-logger = logging.getLogger(__name__)
-
-_ORIG_SDPA = F.scaled_dot_product_attention
-
-
-def _to_bool_4d_mask(
-    attn_mask: Optional[torch.Tensor], q_len: int, kv_len: int, device: torch.device
-) -> Optional[torch.Tensor]:
-    """Normalize additive/other Hugging Face masks into a boolean mask of shape [B, 1, Q, K] (True = masked)."""
-    if attn_mask is None:
-        return None
-    if attn_mask.dtype != torch.bool:
-        attn_mask = attn_mask < 0  # additive -inf -> True
-    if attn_mask.dim() == 4:
-        return attn_mask[..., :q_len, :kv_len].contiguous()
-    if attn_mask.dim() == 3:
-        return attn_mask[:, None, :q_len, :kv_len].contiguous()
-    if attn_mask.dim() == 2:
-        return attn_mask[:, None, None, :kv_len].expand(-1, 1, q_len, -1).contiguous()
-    return attn_mask.to(device)
-
-
-def _merge_causal_mask(
-    attn_mask: Optional[torch.Tensor], is_causal: bool, L: int, S: int, device: torch.device
-) -> Optional[torch.Tensor]:
-    """Merge `is_causal` into the boolean/additive attention mask (True = masked)."""
-    if not is_causal or L != S:
-        return attn_mask
-    causal_bool = torch.ones((1, 1, L, L), dtype=torch.bool, device=device).triu(1)
-    if attn_mask is None:
-        return causal_bool
-    if attn_mask.dtype != torch.bool:
-        attn_mask = attn_mask < 0
-    if attn_mask.dim() == 2:
-        attn_mask = attn_mask[:, None, None, :L].expand(-1, 1, L, -1).contiguous()
-    elif attn_mask.dim() == 3:
-        attn_mask = attn_mask[:, None, :L, :L].contiguous()
-    return attn_mask | causal_bool
-
-
-def _sdpa_npu_redirect(
-    q: torch.Tensor,
-    k: torch.Tensor,
-    v: torch.Tensor,
-    attn_mask: Optional[torch.Tensor] = None,
-    dropout_p: float = 0.0,
-    is_causal: bool = False,
-    scale: Optional[float] = None,
-):
-    """A drop-in replacement for `F.scaled_dot_product_attention`.
-
-    Automatically falls back to the native SDPA when conditions are not met.
-    The NPU-fused path is only enabled when q/k/v have shape (B, N, S, D); otherwise, it falls back.
-    """
-    # Fall back if the feature is disabled or the conditions are not satisfied.
-    if os.environ.get("NPU_FA_DISABLE", "0") == "1":
-        return _ORIG_SDPA(q, k, v, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, scale=scale)
-
-    npu_ok = is_torch_npu_available() and (q.device.type == "npu")
-    dtype_ok = q.dtype in (torch.float16, torch.bfloat16)
-    shape_ok = q.dim() == 4 and k.dim() == 4 and v.dim() == 4  # expect BNSD layout
-    if not (npu_ok and dtype_ok and shape_ok):
-        return _ORIG_SDPA(q, k, v, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, scale=scale)
-
-    L, S = q.size(-2), k.size(-2)
-    merged_mask = _merge_causal_mask(attn_mask, is_causal, L, S, q.device)
-    mask_bool = _to_bool_4d_mask(merged_mask, q_len=L, kv_len=S, device=q.device)
-
-    head_dim = q.size(-1)
-    sc = (1.0 / math.sqrt(head_dim)) if (scale is None) else scale
-
-    train_mode = torch.is_grad_enabled() and (dropout_p > 0)
-    keep_prob = 1.0 - (dropout_p if train_mode else 0.0)
-
-    try:
-        import torch_npu
-
-        out = torch_npu.npu_fusion_attention(
-            q.contiguous(),
-            k.contiguous(),
-            v.contiguous(),
-            head_num=q.size(-3),  # N
-            input_layout="BNSD",  # (B, N, S, D)
-            pse=None,
-            atten_mask=mask_bool,  # True = masked
-            scale=sc,
-            pre_tockens=2147483647,
-            next_tockens=2147483647,
-            keep_prob=keep_prob,
-            sync=False,
-            inner_precise=0,
-        )[0]
-        return out
-    except Exception as e:
-        if os.environ.get("NPU_FA_VERBOSE", "0") == "1":
-            logger.warning(f"[sdpa_npu_redirect] npu_fusion_attention failed: {e}; fallback to SDPA.")
-        return _ORIG_SDPA(q, k, v, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, scale=scale)
-
-
-def apply_sdpa_npu_redirect(verbose: bool = True):
-    """Install the redirection by pointing `F.scaled_dot_product_attention` to our implementation."""
-    if getattr(F.scaled_dot_product_attention, "__wrapped_by_npu__", False):
-        return
-    F.scaled_dot_product_attention = _sdpa_npu_redirect
-    setattr(F.scaled_dot_product_attention, "__wrapped_by_npu__", True)
-    if verbose:
-        logger.info("[sdpa_npu_redirect] SDPA has been redirected to Ascend npu_fusion_attention when available.")
diff --git a/src/llamafactory/model/patcher.py b/src/llamafactory/model/patcher.py
index d4681f5a..cbf9aea3 100644
--- a/src/llamafactory/model/patcher.py
+++ b/src/llamafactory/model/patcher.py
@@ -193,23 +193,6 @@ def patch_model(
     if not model_args.use_unsloth:
         print_attn_implementation(model.config)
 
-    # ======== NPU fused attention redirect: SDPA -> torch_npu.npu_fusion_attention ========
-    # Place after all structural modifications and before DeepSpeed/Trainer initialization;
-    # does not modify any Module/_parameters, safe for ZeRO-3 + offload.
-    try:
-        import os
-
-        import torch
-
-        if hasattr(torch, "npu") and torch.npu.is_available() and os.environ.get("NPU_FA_DISABLE", "0") != "1":
-            from .model_utils.sdpa_npu_redirect import apply_sdpa_npu_redirect
-
-            apply_sdpa_npu_redirect(verbose=not model_args.use_unsloth)
-            logger.info_rank0("[sdpa_npu_redirect] Enabled: SDPA will use Ascend npu_fusion_attention when available.")
-    except Exception as e:
-        logger.warning_rank0(f"[sdpa_npu_redirect] Failed to enable redirect, will keep native SDPA. Reason: {e}")
-    # =====================================================================================
-
     try:
         model.add_model_tags(["llama-factory"])
     except Exception:
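
For reference, the redirection removed by this diff can be reproduced standalone with a small monkey patch. The sketch below is illustrative only: it assumes an Ascend device with torch_npu available, reuses only the npu_fusion_attention arguments already shown in the deleted module, and covers just the simplest case (no attention mask, no causal masking, no dropout); the helper name _sdpa_maybe_fused is invented for this example and is not part of the project.

import math

import torch
import torch.nn.functional as F
import torch_npu  # assumption: an Ascend build of PyTorch with torch_npu installed

_ORIG_SDPA = F.scaled_dot_product_attention


def _sdpa_maybe_fused(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None):
    # Use the fused kernel only for the plain case: 4-D half-precision tensors on the NPU,
    # no mask, no causal masking, no dropout; everything else falls back to native SDPA.
    simple = (
        q.device.type == "npu"
        and q.dim() == 4
        and q.dtype in (torch.float16, torch.bfloat16)
        and attn_mask is None
        and not is_causal
        and dropout_p == 0.0
    )
    if not simple:
        return _ORIG_SDPA(q, k, v, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, scale=scale)
    sc = scale if scale is not None else 1.0 / math.sqrt(q.size(-1))
    # npu_fusion_attention returns a tuple; the attention output is the first element.
    return torch_npu.npu_fusion_attention(
        q.contiguous(), k.contiguous(), v.contiguous(),
        head_num=q.size(-3), input_layout="BNSD", scale=sc, keep_prob=1.0,
    )[0]


F.scaled_dot_product_attention = _sdpa_maybe_fused  # same redirection the deleted module installed

Installing the patch before model construction mirrors what apply_sdpa_npu_redirect did in patcher.py; the deleted module additionally normalized Hugging Face masks to boolean BNSD masks and merged is_causal before calling the fused kernel.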