mirror of https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-11-04 18:02:19 +08:00

reenable sdpa and fast tok by default

Former-commit-id: 9e00902dbedc71d55743d1bf237843506a557891

This commit is contained in:
parent 35c4a2c212
commit d2bb1b3a6b
@@ -72,8 +72,6 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/

[24/04/21] We supported **[Mixture-of-Depths](https://arxiv.org/abs/2404.02258)** according to [AstraMindAI's implementation](https://github.com/astramind-ai/Mixture-of-depths). See `examples/extras/mod` for usage.

[24/04/19] We supported **Meta Llama 3** model series.

[24/04/16] We supported **[BAdam](https://arxiv.org/abs/2404.02827)**. See `examples/extras/badam` for usage.

[24/04/16] We supported **[unsloth](https://github.com/unslothai/unsloth)**'s long-sequence training (Llama-2-7B-56k within 24GB). It achieves **117%** speed and **50%** memory compared with FlashAttention-2, more benchmarks can be found in [this page](https://github.com/hiyouga/LLaMA-Factory/wiki/Performance-comparison).
@@ -112,7 +110,7 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/

[23/09/23] We integrated MMLU, C-Eval and CMMLU benchmarks in this repo. See [this example](#evaluation) to evaluate your models.

-[23/09/10] We supported **[FlashAttention-2](https://github.com/Dao-AILab/flash-attention)**. Try `--flash_attn` argument to enable FlashAttention-2 if you are using RTX4090, A100 or H100 GPUs.
+[23/09/10] We supported **[FlashAttention-2](https://github.com/Dao-AILab/flash-attention)**. Try `--flash_attn fa2` argument to enable FlashAttention-2 if you are using RTX4090, A100 or H100 GPUs.

[23/08/12] We supported **RoPE scaling** to extend the context length of the LLaMA models. Try `--rope_scaling linear` argument in training and `--rope_scaling dynamic` argument at inference to extrapolate the position embeddings.
@@ -72,8 +72,6 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd

[24/04/21] We supported **[Mixture-of-Depths training](https://arxiv.org/abs/2404.02258)** based on [AstraMindAI's repository](https://github.com/astramind-ai/Mixture-of-depths). See `examples/extras/mod` for usage details.

[24/04/19] We supported the **Meta Llama 3** model series.

[24/04/16] We supported **[BAdam](https://arxiv.org/abs/2404.02827)**. See `examples/extras/badam` for usage details.

[24/04/16] We supported **[unsloth](https://github.com/unslothai/unsloth)**'s long-sequence training (Llama-2-7B-56k trainable within 24GB). Compared with FlashAttention-2, it offers **117%** training speed and **50%** memory savings. More benchmarks can be found on [this page](https://github.com/hiyouga/LLaMA-Factory/wiki/Performance-comparison).
@@ -112,7 +110,7 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd

[23/09/23] We integrated the MMLU, C-Eval and CMMLU benchmarks into this project. See [this example](#模型评估) for usage.

-[23/09/10] We supported **[FlashAttention-2](https://github.com/Dao-AILab/flash-attention)**. If you are using an RTX4090, A100 or H100 GPU, use the `--flash_attn` argument to enable FlashAttention-2.
+[23/09/10] We supported **[FlashAttention-2](https://github.com/Dao-AILab/flash-attention)**. If you are using an RTX4090, A100 or H100 GPU, use the `--flash_attn fa2` argument to enable FlashAttention-2.

[23/08/12] We supported **RoPE scaling** to extend the context length of LLaMA models. Use the `--rope_scaling linear` argument for training and the `--rope_scaling dynamic` argument for evaluation to extrapolate the position embeddings.
@@ -15,3 +15,4 @@ fastapi
sse-starlette
matplotlib
fire
+packaging
@@ -1,16 +1,23 @@
import importlib.metadata
import importlib.util
+from typing import TYPE_CHECKING
+
+from packaging import version
+
+
+if TYPE_CHECKING:
+    from packaging.version import Version


def _is_package_available(name: str) -> bool:
    return importlib.util.find_spec(name) is not None


-def _get_package_version(name: str) -> str:
+def _get_package_version(name: str) -> "Version":
    try:
-        return importlib.metadata.version(name)
+        return version.parse(importlib.metadata.version(name))
    except Exception:
-        return "0.0.0"
+        return version.parse("0.0.0")


def is_fastapi_availble():
@@ -18,7 +25,7 @@ def is_fastapi_availble():


def is_flash_attn2_available():
-    return _is_package_available("flash_attn") and _get_package_version("flash_attn").startswith("2")
+    return _is_package_available("flash_attn") and _get_package_version("flash_attn") > version.parse("2.0.0")


def is_galore_available():
@@ -49,6 +56,10 @@ def is_rouge_available():
    return _is_package_available("rouge_chinese")


+def is_sdpa_available():
+    return _get_package_version("torch") > version.parse("2.1.1")
+
+
def is_starlette_available():
    return _is_package_available("sse_starlette")
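A standalone sketch (not part of the commit) of why the helper now returns a parsed `Version` instead of a raw string: plain string comparisons order release numbers lexicographically, so checks like `"2.10.0" > "2.9.0"` come out wrong, whereas `packaging.version` compares release segments numerically. The function below mirrors the one in the diff for illustration only.

```python
import importlib.metadata

from packaging import version


def _get_package_version(name: str) -> "version.Version":
    # Return a comparable Version object; fall back to 0.0.0 if the package is missing.
    try:
        return version.parse(importlib.metadata.version(name))
    except Exception:
        return version.parse("0.0.0")


# Lexicographic string comparison misorders versions once a component reaches two digits:
assert not ("2.10.0" > "2.9.0")                          # string comparison says 2.10.0 is "smaller"
assert version.parse("2.10.0") > version.parse("2.9.0")  # numeric comparison is correct

# The same pattern backs the new is_sdpa_available() check (torch > 2.1.1):
print(_get_package_version("torch") > version.parse("2.1.1"))
```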
@@ -22,7 +22,7 @@ class ModelArguments:
        metadata={"help": "Where to store the pre-trained models downloaded from huggingface.co or modelscope.cn."},
    )
    use_fast_tokenizer: bool = field(
-        default=False,
+        default=True,
        metadata={"help": "Whether or not to use one of the fast tokenizer (backed by the tokenizers library)."},
    )
    resize_vocab: bool = field(
@@ -61,9 +61,9 @@ class ModelArguments:
        default=None,
        metadata={"help": "Which scaling strategy should be adopted for the RoPE embeddings."},
    )
-    flash_attn: bool = field(
-        default=False,
-        metadata={"help": "Enable FlashAttention for faster training."},
+    flash_attn: Literal["off", "sdpa", "fa2", "auto"] = field(
+        default="auto",
+        metadata={"help": "Enable FlashAttention for faster training and inference."},
    )
    shift_attn: bool = field(
        default=False,
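For context (not part of the commit): moving from a boolean to a `Literal[...]` field means the flag now takes one of four explicit values instead of on/off. A minimal stand-alone sketch of that pattern, using plain `argparse` and `typing.get_args` rather than the project's actual argument parser:

```python
import argparse
from dataclasses import dataclass, field
from typing import Literal, get_args

AttnChoice = Literal["off", "sdpa", "fa2", "auto"]


@dataclass
class ModelArguments:
    # Mirrors the new field: a four-way choice instead of a True/False switch.
    flash_attn: AttnChoice = field(default="auto")


parser = argparse.ArgumentParser()
parser.add_argument("--flash_attn", choices=get_args(AttnChoice), default="auto")

args = parser.parse_args(["--flash_attn", "fa2"])
print(ModelArguments(flash_attn=args.flash_attn))  # ModelArguments(flash_attn='fa2')
```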
@@ -15,7 +15,7 @@ from transformers.utils.versions import require_version
from ..extras.constants import FILEEXT2TYPE, LAYERNORM_NAMES
from ..extras.logging import get_logger
from ..extras.misc import get_current_device, infer_optim_dtype
-from ..extras.packages import is_flash_attn2_available
+from ..extras.packages import is_flash_attn2_available, is_sdpa_available
from ..extras.patches.llama_patch import apply_llama_patch
from .utils import QuantizationMethod, add_z3_leaf_module, gradient_checkpointing_enable
@@ -62,18 +62,45 @@ def _get_quantization_dataset(tokenizer: "PreTrainedTokenizer", model_args: "Mod


def _configure_attn_implementation(config: "PretrainedConfig", model_args: "ModelArguments") -> None:
-    if model_args.flash_attn:
-        if not is_flash_attn2_available():
-            logger.warning("FlashAttention2 is not installed.")
-            return
-
-        logger.info("Using FlashAttention-2 for faster training and inference.")
-        if getattr(config, "model_type", None) == "internlm2":  # special case for custom models
-            setattr(config, "attn_implementation", "flash_attention_2")
-        else:
-            setattr(config, "_attn_implementation", "flash_attention_2")
-    else:
-        setattr(config, "_attn_implementation", "eager")
+    if model_args.flash_attn == "auto":
+        return
+
+    elif model_args.flash_attn == "off":
+        requested_attn_implementation = "eager"
+
+    elif model_args.flash_attn == "sdpa":
+        if not is_sdpa_available():
+            logger.warning("Torch>=2.1.1 is required for SDPA attention.")
+            return
+
+        requested_attn_implementation = "sdpa"
+    elif model_args.flash_attn == "fa2":
+        if not is_flash_attn2_available():
+            logger.warning("FlashAttention-2 is not installed.")
+            return
+
+        requested_attn_implementation = "flash_attention_2"
+    else:
+        raise NotImplementedError("Unknown attention type: {}".format(model_args.flash_attn))
+
+    if getattr(config, "model_type", None) == "internlm2":  # special case for custom models
+        setattr(config, "attn_implementation", requested_attn_implementation)
+    else:
+        setattr(config, "_attn_implementation", requested_attn_implementation)
+
+
+def _print_attn_implementation(config: "PretrainedConfig") -> None:
+    if getattr(config, "model_type", None) == "internlm2":  # special case for custom models
+        attn_implementation = getattr(config, "attn_implementation", None)
+    else:
+        attn_implementation = getattr(config, "_attn_implementation", None)
+
+    if attn_implementation == "flash_attention_2":
+        logger.info("Using FlashAttention-2 for faster training and inference.")
+    elif attn_implementation == "sdpa":
+        logger.info("Using torch SDPA for faster training and inference.")
+    else:
+        logger.info("Using vanilla Attention implementation.")


def _configure_rope(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
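As a rough stand-alone illustration (not part of the commit) of what the new branches resolve to: the four flag values map onto the attention backend names that Hugging Face Transformers recognizes ("eager", "sdpa", "flash_attention_2"), with "auto" leaving the library's own default selection untouched. The function name below is hypothetical.

```python
from typing import Optional


def resolve_attn_implementation(flash_attn: str) -> Optional[str]:
    # "auto" -> None, i.e. do not override the library default.
    if flash_attn == "auto":
        return None
    mapping = {"off": "eager", "sdpa": "sdpa", "fa2": "flash_attention_2"}
    if flash_attn not in mapping:
        raise NotImplementedError("Unknown attention type: {}".format(flash_attn))
    return mapping[flash_attn]


for choice in ("auto", "off", "sdpa", "fa2"):
    print(choice, "->", resolve_attn_implementation(choice))
```

When loading a model directly with a recent Transformers release (roughly 4.36 or later), the same string can be passed as `from_pretrained(..., attn_implementation=...)`; the patcher instead records it on the config object before the model is instantiated, which serves the same purpose.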
@@ -365,6 +392,8 @@ def patch_model(

        add_z3_leaf_module(model, Qwen2MoeSparseMoeBlock)

+    _print_attn_implementation(model.config)
+
    try:
        model.add_model_tags(["llama-factory"])
    except Exception:
@@ -33,7 +33,7 @@ def create_top() -> Dict[str, "Component"]:
            quantization_bit = gr.Dropdown(choices=["none", "8", "4"], value="none")
            template = gr.Dropdown(choices=list(templates.keys()), value="default")
            rope_scaling = gr.Radio(choices=["none", "linear", "dynamic"], value="none")
-            booster = gr.Radio(choices=["none", "flashattn", "unsloth"], value="none")
+            booster = gr.Radio(choices=["none", "flashattn2", "unsloth"], value="none")

    model_name.change(list_adapters, [model_name, finetuning_type], [adapter_path], queue=False).then(
        get_model_path, [model_name], [model_path], queue=False
@@ -67,7 +67,7 @@ class Runner:
        if not model_path:
            return ALERTS["err_no_path"][lang]

-        if len(dataset) == 0:
+        if not dataset:
            return ALERTS["err_no_dataset"][lang]

        if not from_preview and self.demo_mode:
@@ -122,7 +122,7 @@ class Runner:
            quantization_bit=int(get("top.quantization_bit")) if get("top.quantization_bit") in ["8", "4"] else None,
            template=get("top.template"),
            rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None,
-            flash_attn=(get("top.booster") == "flashattn"),
+            flash_attn="fa2" if get("top.booster") == "flashattn2" else "auto",
            use_unsloth=(get("top.booster") == "unsloth"),
            dataset_dir=get("train.dataset_dir"),
            dataset=",".join(get("train.dataset")),
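Taken together, the Web UI changes route the renamed "flashattn2" booster choice into the new string-valued flag. A small stand-alone sketch of that mapping (the function name and dict layout are illustrative, not the project's real API):

```python
def booster_to_args(booster: str) -> dict:
    # Mirrors the runner.py change: only "flashattn2" requests FlashAttention-2 explicitly,
    # everything else defers to "auto"; "unsloth" is handled by a separate flag.
    return {
        "flash_attn": "fa2" if booster == "flashattn2" else "auto",
        "use_unsloth": booster == "unsloth",
    }


print(booster_to_args("flashattn2"))  # {'flash_attn': 'fa2', 'use_unsloth': False}
print(booster_to_args("none"))        # {'flash_attn': 'auto', 'use_unsloth': False}
```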