mirror of https://github.com/hiyouga/LLaMA-Factory.git
add npu examples
@@ -1,9 +1,10 @@
+import os
 from types import MethodType
 from typing import TYPE_CHECKING, Any, Dict
 
 import torch
 from peft import PeftModel
-from transformers import PreTrainedModel, PreTrainedTokenizerBase
+from transformers import PreTrainedModel, PreTrainedTokenizerBase, is_torch_npu_available
 from transformers.integrations import is_deepspeed_zero3_enabled
 
 from ..extras.logging import get_logger
@@ -44,6 +45,10 @@ def patch_config(
     if model_args.compute_dtype is None:  # priority: bf16 > fp16 > fp32
         model_args.compute_dtype = infer_optim_dtype(model_dtype=getattr(config, "torch_dtype", None))
 
+    if is_torch_npu_available():
+        use_jit_compile = os.environ.get("JIT_COMPILE", "0").lower() in ["true", "1"]
+        torch.npu.set_compile_mode(jit_compile=use_jit_compile)
+
     configure_attn_implementation(config, model_args)
     configure_rope(config, model_args, is_trainable)
     configure_longlora(config, model_args, is_trainable)
@@ -56,7 +61,7 @@ def patch_config(
         logger.info("Using KV cache for faster generation.")
 
     if getattr(config, "model_type", None) == "qwen":
-        setattr(config, "use_flash_attn", model_args.flash_attn)
+        setattr(config, "use_flash_attn", model_args.flash_attn == "fa2")
         for dtype_name, dtype in [("fp16", torch.float16), ("bf16", torch.bfloat16), ("fp32", torch.float32)]:
             setattr(config, dtype_name, model_args.compute_dtype == dtype)
 
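Note on the NPU hunk: when an Ascend NPU is detected, the new branch reads the JIT_COMPILE environment variable (default "0", i.e. disabled) and passes the result to torch.npu.set_compile_mode. A minimal standalone sketch of the same gate, assuming the torch_npu backend (which provides torch.npu.set_compile_mode) is installed; maybe_set_npu_compile_mode is a hypothetical helper name, not part of the project:

import os

import torch
from transformers import is_torch_npu_available


def maybe_set_npu_compile_mode() -> None:
    # Safe to call on any machine: without an Ascend NPU / torch_npu it is a no-op.
    if not is_torch_npu_available():
        return
    # JIT compilation is opt-in: export JIT_COMPILE=1 (or "true") to enable it;
    # any other value keeps jit_compile=False.
    use_jit_compile = os.environ.get("JIT_COMPILE", "0").lower() in ["true", "1"]
    torch.npu.set_compile_mode(jit_compile=use_jit_compile)


if __name__ == "__main__":
    maybe_set_npu_compile_mode()

Because the check short-circuits on non-NPU machines, patch_config can call it unconditionally.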
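Note on the last hunk: model_args.flash_attn is a string choice (e.g. "fa2" for FlashAttention-2) rather than a boolean, so the Qwen (v1) config, which expects a boolean use_flash_attn flag, is now enabled only when "fa2" is selected. A hedged sketch of that conversion on a stand-in config object (patch_qwen_config and the SimpleNamespace config are illustrative, not the project's API):

from types import SimpleNamespace

import torch


def patch_qwen_config(config, flash_attn: str, compute_dtype: torch.dtype) -> None:
    # Collapse the string option to "is FlashAttention-2 requested?".
    setattr(config, "use_flash_attn", flash_attn == "fa2")
    # Qwen (v1) also expects explicit dtype booleans on its config.
    for dtype_name, dtype in [("fp16", torch.float16), ("bf16", torch.bfloat16), ("fp32", torch.float32)]:
        setattr(config, dtype_name, compute_dtype == dtype)


config = SimpleNamespace()
patch_qwen_config(config, flash_attn="fa2", compute_dtype=torch.bfloat16)
print(config.use_flash_attn, config.fp16, config.bf16, config.fp32)  # True False True False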