mirror of https://github.com/hiyouga/LLaMA-Factory.git, synced 2025-11-07 11:22:20 +08:00

Compare commits: a3c2b6139c ... 4f2f058d42 (3 commits)

| Author | SHA1 | Date |
|---|---|---|
| | 4f2f058d42 | |
| | d55091ea87 | |
| | 44131fdb2a | |
@@ -24,9 +24,6 @@ from typing import TYPE_CHECKING, Any, Optional
 from ..extras.constants import EngineName
 from ..extras.misc import torch_gc
 from ..hparams import get_infer_args
-from .hf_engine import HuggingfaceEngine
-from .sglang_engine import SGLangEngine
-from .vllm_engine import VllmEngine


 if TYPE_CHECKING:
@@ -49,12 +46,28 @@ class ChatModel:

     def __init__(self, args: Optional[dict[str, Any]] = None) -> None:
         model_args, data_args, finetuning_args, generating_args = get_infer_args(args)

         if model_args.infer_backend == EngineName.HF:
+            from .hf_engine import HuggingfaceEngine
+
             self.engine: BaseEngine = HuggingfaceEngine(model_args, data_args, finetuning_args, generating_args)
         elif model_args.infer_backend == EngineName.VLLM:
-            self.engine: BaseEngine = VllmEngine(model_args, data_args, finetuning_args, generating_args)
+            try:
+                from .vllm_engine import VllmEngine
+                self.engine: BaseEngine = VllmEngine(model_args, data_args, finetuning_args, generating_args)
+            except ImportError as e:
+                raise ImportError(
+                    "vLLM not install, you may need to run `pip install vllm`\n"
+                    "or try to use HuggingFace backend: --infer_backend huggingface"
+                ) from e
         elif model_args.infer_backend == EngineName.SGLANG:
-            self.engine: BaseEngine = SGLangEngine(model_args, data_args, finetuning_args, generating_args)
+            try:
+                from .sglang_engine import SGLangEngine
+                self.engine: BaseEngine = SGLangEngine(model_args, data_args, finetuning_args, generating_args)
+            except ImportError as e:
+                raise ImportError(
+                    "SGLang not install, you may need to run `pip install sglang[all]`\n"
+                    "or try to use HuggingFace backend: --infer_backend huggingface"
+                ) from e
         else:
             raise NotImplementedError(f"Unknown backend: {model_args.infer_backend}")
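Note: with the backend imports deferred into `__init__`, constructing a `ChatModel` only needs the package behind the selected `infer_backend`, and picking a backend that is not installed fails with the actionable `ImportError` above instead of breaking at module import time. A minimal usage sketch (model name and template are illustrative placeholders, not part of this change):

    from llamafactory.chat import ChatModel

    # Only the HuggingFace engine module is imported here; vllm/sglang may be absent.
    chat_model = ChatModel({
        "model_name_or_path": "Qwen/Qwen2.5-0.5B-Instruct",  # placeholder model
        "template": "qwen",
        "infer_backend": "huggingface",
    })
    messages = [{"role": "user", "content": "Hello!"}]
    print(chat_model.chat(messages)[0].response_text)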
@@ -35,16 +35,46 @@ USAGE = (
 )


+def _run_api():
+    from .api.app import run_api
+    return run_api()
+
+
+def _run_chat():
+    from .chat.chat_model import run_chat
+    return run_chat()
+
+
+def _run_eval():
+    from .eval.evaluator import run_eval
+    return run_eval()
+
+
+def _export_model():
+    from .train.tuner import export_model
+    return export_model()
+
+
+def _run_exp():
+    from .train.tuner import run_exp
+    return run_exp()
+
+
+def _run_web_demo():
+    from .webui.interface import run_web_demo
+    return run_web_demo()
+
+
+def _run_web_ui():
+    from .webui.interface import run_web_ui
+    return run_web_ui()
+
+
 def main():
     from . import launcher
-    from .api.app import run_api
-    from .chat.chat_model import run_chat
-    from .eval.evaluator import run_eval
     from .extras import logging
     from .extras.env import VERSION, print_env
     from .extras.misc import find_available_port, get_device_count, is_env_enabled, use_ray
-    from .train.tuner import export_model, run_exp
-    from .webui.interface import run_web_demo, run_web_ui

     logger = logging.get_logger(__name__)

@@ -61,14 +91,14 @@ def main():
     )

     COMMAND_MAP = {
-        "api": run_api,
-        "chat": run_chat,
+        "api": _run_api,
+        "chat": _run_chat,
         "env": print_env,
-        "eval": run_eval,
-        "export": export_model,
-        "train": run_exp,
-        "webchat": run_web_demo,
-        "webui": run_web_ui,
+        "eval": _run_eval,
+        "export": _export_model,
+        "train": _run_exp,
+        "webchat": _run_web_demo,
+        "webui": _run_web_ui,
         "version": partial(print, WELCOME),
         "help": partial(print, USAGE),
     }
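With the `_run_*` wrappers, `COMMAND_MAP` still maps each sub-command to a plain callable, so dispatch in `main()` stays a single lookup-and-call, while commands such as `version` and `help` no longer trigger imports of the API, chat, eval, train, and webui stacks. A stripped-down sketch of that dispatch pattern (illustrative only; the real `main()` also handles distributed launch and other options):

    import sys
    from functools import partial

    WELCOME = "Welcome to LLaMA Factory"
    USAGE = "Usage: llamafactory-cli api | chat | eval | export | train | webchat | webui | version | help"


    def _run_api():
        from llamafactory.api.app import run_api  # imported only when `api` is actually requested
        return run_api()


    COMMAND_MAP = {
        "api": _run_api,
        "version": partial(print, WELCOME),
        "help": partial(print, USAGE),
    }

    command = sys.argv.pop(1) if len(sys.argv) > 1 else "help"
    COMMAND_MAP.get(command, partial(print, USAGE))()  # unknown commands fall back to the usage text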
@@ -416,8 +416,8 @@ class ReasoningTemplate(Template):

         prompt_ids, response_ids = super().encode_oneturn(tokenizer, messages, system, tools)
         if (
-            self.thought_words[0] not in messages[-1]["content"]
-            and self.thought_words[1] not in messages[-1]["content"]
+            self.thought_words[0].strip() not in messages[-1]["content"]
+            and self.thought_words[1].strip() not in messages[-1]["content"]
         ):  # add empty cot
             if not self.enable_thinking:  # do not compute loss
                 prompt_ids += self.get_thought_word_ids(tokenizer)
@@ -442,8 +442,8 @@ class ReasoningTemplate(Template):
         encoded_messages = self._encode(tokenizer, messages, system, tools)
         for i in range(0, len(messages), 2):
             if (
-                self.thought_words[0] not in messages[i + 1]["content"]
-                and self.thought_words[1] not in messages[i + 1]["content"]
+                self.thought_words[0].strip() not in messages[i + 1]["content"]
+                and self.thought_words[1].strip() not in messages[i + 1]["content"]
             ):  # add empty cot
                 if not self.enable_thinking:  # do not compute loss
                     encoded_messages[i] += self.get_thought_word_ids(tokenizer)
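The `.strip()` matters because a template's thought words may carry surrounding whitespace, in which case an exact substring test can miss markers that are clearly present in the assistant message and wrongly prepend an empty CoT. A small illustration (the whitespace-bearing thought words are assumed here for the sake of the example):

    thought_words = ("<think>\n", "\n</think>\n\n")  # assumed example values
    content = "<think>Let me reason step by step.</think>The answer is 42."

    print(thought_words[0] in content)          # False: the trailing "\n" breaks the exact match
    print(thought_words[0].strip() in content)  # True: the stripped marker is found, so no empty CoT is added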
src/llamafactory/model/model_utils/sdpa_npu_redirect.py (new file, 131 lines)
@@ -0,0 +1,131 @@
+# Copyright 2025 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import math
+import os
+from typing import Optional
+
+import torch
+import torch.nn.functional as F
+from transformers.utils import is_torch_npu_available
+
+
+logger = logging.getLogger(__name__)
+
+_ORIG_SDPA = F.scaled_dot_product_attention
+
+
+def _to_bool_4d_mask(
+    attn_mask: Optional[torch.Tensor], q_len: int, kv_len: int, device: torch.device
+) -> Optional[torch.Tensor]:
+    """Normalize additive/other Hugging Face masks into a boolean mask of shape [B, 1, Q, K] (True = masked)."""
+    if attn_mask is None:
+        return None
+    if attn_mask.dtype != torch.bool:
+        attn_mask = attn_mask < 0  # additive -inf -> True
+    if attn_mask.dim() == 4:
+        return attn_mask[..., :q_len, :kv_len].contiguous()
+    if attn_mask.dim() == 3:
+        return attn_mask[:, None, :q_len, :kv_len].contiguous()
+    if attn_mask.dim() == 2:
+        return attn_mask[:, None, None, :kv_len].expand(-1, 1, q_len, -1).contiguous()
+    return attn_mask.to(device)
+
+
+def _merge_causal_mask(
+    attn_mask: Optional[torch.Tensor], is_causal: bool, L: int, S: int, device: torch.device
+) -> Optional[torch.Tensor]:
+    """Merge `is_causal` into the boolean/additive attention mask (True = masked)."""
+    if not is_causal or L != S:
+        return attn_mask
+    causal_bool = torch.ones((1, 1, L, L), dtype=torch.bool, device=device).triu(1)
+    if attn_mask is None:
+        return causal_bool
+    if attn_mask.dtype != torch.bool:
+        attn_mask = attn_mask < 0
+    if attn_mask.dim() == 2:
+        attn_mask = attn_mask[:, None, None, :L].expand(-1, 1, L, -1).contiguous()
+    elif attn_mask.dim() == 3:
+        attn_mask = attn_mask[:, None, :L, :L].contiguous()
+    return attn_mask | causal_bool
+
+
+def _sdpa_npu_redirect(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    attn_mask: Optional[torch.Tensor] = None,
+    dropout_p: float = 0.0,
+    is_causal: bool = False,
+    scale: Optional[float] = None,
+):
+    """A drop-in replacement for `F.scaled_dot_product_attention`.
+
+    Automatically falls back to the native SDPA when conditions are not met.
+    The NPU-fused path is only enabled when q/k/v have shape (B, N, S, D); otherwise, it falls back.
+    """
+    # Fall back if the feature is disabled or the conditions are not satisfied.
+    if os.environ.get("NPU_FA_DISABLE", "0") == "1":
+        return _ORIG_SDPA(q, k, v, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, scale=scale)
+
+    npu_ok = is_torch_npu_available() and (q.device.type == "npu")
+    dtype_ok = q.dtype in (torch.float16, torch.bfloat16)
+    shape_ok = q.dim() == 4 and k.dim() == 4 and v.dim() == 4  # expects BNSD layout
+    if not (npu_ok and dtype_ok and shape_ok):
+        return _ORIG_SDPA(q, k, v, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, scale=scale)
+
+    L, S = q.size(-2), k.size(-2)
+    merged_mask = _merge_causal_mask(attn_mask, is_causal, L, S, q.device)
+    mask_bool = _to_bool_4d_mask(merged_mask, q_len=L, kv_len=S, device=q.device)
+
+    head_dim = q.size(-1)
+    sc = (1.0 / math.sqrt(head_dim)) if (scale is None) else scale
+
+    train_mode = torch.is_grad_enabled() and (dropout_p > 0)
+    keep_prob = 1.0 - (dropout_p if train_mode else 0.0)
+
+    try:
+        import torch_npu
+
+        out = torch_npu.npu_fusion_attention(
+            q.contiguous(),
+            k.contiguous(),
+            v.contiguous(),
+            head_num=q.size(-3),  # N
+            input_layout="BNSD",  # (B, N, S, D)
+            pse=None,
+            atten_mask=mask_bool,  # True = masked
+            scale=sc,
+            pre_tockens=2147483647,
+            next_tockens=2147483647,
+            keep_prob=keep_prob,
+            sync=False,
+            inner_precise=0,
+        )[0]
+        return out
+    except Exception as e:
+        if os.environ.get("NPU_FA_VERBOSE", "0") == "1":
+            logger.warning(f"[sdpa_npu_redirect] npu_fusion_attention failed: {e}; fallback to SDPA.")
+        return _ORIG_SDPA(q, k, v, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, scale=scale)
+
+
+def apply_sdpa_npu_redirect(verbose: bool = True):
+    """Install the redirection by pointing `F.scaled_dot_product_attention` to our implementation."""
+    if getattr(F.scaled_dot_product_attention, "__wrapped_by_npu__", False):
+        return
+    F.scaled_dot_product_attention = _sdpa_npu_redirect
+    setattr(F.scaled_dot_product_attention, "__wrapped_by_npu__", True)
+    if verbose:
+        logger.info("[sdpa_npu_redirect] SDPA has been redirected to Ascend npu_fusion_attention when available.")
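Because the module swaps out `torch.nn.functional.scaled_dot_product_attention` globally, the redirect can be sanity-checked outside the trainer by installing it and calling SDPA directly; on CPU, unsupported dtypes, or with `NPU_FA_DISABLE=1`, the wrapper simply falls through to the original kernel. A minimal sketch (assumes the new module is importable from the installed package):

    import torch
    import torch.nn.functional as F

    from llamafactory.model.model_utils.sdpa_npu_redirect import apply_sdpa_npu_redirect

    apply_sdpa_npu_redirect(verbose=True)
    print(F.scaled_dot_product_attention.__name__)  # "_sdpa_npu_redirect" once installed

    # On CPU/float32 the wrapper takes the fallback branch and behaves like native SDPA.
    q = k = v = torch.randn(1, 2, 8, 16)
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
    print(out.shape)  # torch.Size([1, 2, 8, 16])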
@@ -188,6 +188,23 @@ def patch_model(
     if not model_args.use_unsloth:
         print_attn_implementation(model.config)

+    # ======== NPU fused attention redirect: SDPA -> torch_npu.npu_fusion_attention ========
+    # Place after all structural modifications and before DeepSpeed/Trainer initialization;
+    # does not modify any Module/_parameters, safe for ZeRO-3 + offload.
+    try:
+        import os
+
+        import torch
+
+        if hasattr(torch, "npu") and torch.npu.is_available() and os.environ.get("NPU_FA_DISABLE", "0") != "1":
+            from .model_utils.sdpa_npu_redirect import apply_sdpa_npu_redirect
+
+            apply_sdpa_npu_redirect(verbose=not model_args.use_unsloth)
+            logger.info_rank0("[sdpa_npu_redirect] Enabled: SDPA will use Ascend npu_fusion_attention when available.")
+    except Exception as e:
+        logger.warning_rank0(f"[sdpa_npu_redirect] Failed to enable redirect, will keep native SDPA. Reason: {e}")
+    # =====================================================================================
+
     try:
         model.add_model_tags(["llama-factory"])
     except Exception:
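`NPU_FA_DISABLE` is checked both when `patch_model` installs the redirect and on every SDPA call, and `NPU_FA_VERBOSE` controls whether fallbacks from `npu_fusion_attention` are logged. A launcher script can therefore opt out or turn on fallback logging before training starts (the values shown are the ones the code above reads):

    import os

    os.environ["NPU_FA_DISABLE"] = "1"   # keep native SDPA; patch_model skips the redirect
    # or
    os.environ["NPU_FA_VERBOSE"] = "1"   # keep the redirect, log whenever it falls back to SDPA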