Mirror of https://github.com/hiyouga/LLaMA-Factory.git
Synced 2025-12-17 20:30:36 +08:00
[config] update args (#7231)
Former-commit-id: f71a901840811bf560df671ec63a146ff99140c6
@@ -17,6 +17,7 @@ from typing import TYPE_CHECKING
 
 from transformers.utils import is_flash_attn_2_available, is_torch_sdpa_available
 
 from ...extras import logging
+from ...extras.constants import AttentionFunction
 from ...extras.misc import check_version
 
@@ -33,34 +34,34 @@ def configure_attn_implementation(
     config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool
 ) -> None:
     if getattr(config, "model_type", None) == "gemma2" and is_trainable:
-        if model_args.flash_attn == "auto" or model_args.flash_attn == "fa2":
+        if model_args.flash_attn == AttentionFunction.AUTO or model_args.flash_attn == AttentionFunction.FA2:
             if is_flash_attn_2_available():
                 check_version("transformers>=4.42.4")
                 check_version("flash_attn>=2.6.3")
-                if model_args.flash_attn != "fa2":
-                    logger.warning_rank0("Gemma-2 should use flash attention 2, change `flash_attn` to fa2.")
-                    model_args.flash_attn = "fa2"
+                if model_args.flash_attn != AttentionFunction.FA2:
+                    logger.warning_rank0("Gemma 2 should use flash attention 2, change `flash_attn` to fa2.")
+                    model_args.flash_attn = AttentionFunction.FA2
             else:
                 logger.warning_rank0("FlashAttention-2 is not installed, use eager attention.")
-                model_args.flash_attn = "disabled"
-        elif model_args.flash_attn == "sdpa":
+                model_args.flash_attn = AttentionFunction.DISABLED
+        elif model_args.flash_attn == AttentionFunction.SDPA:
             logger.warning_rank0(
                 "Gemma-2 should use soft-capping attention, while the SDPA attention does not support it."
             )
 
-    if model_args.flash_attn == "auto":
+    if model_args.flash_attn == AttentionFunction.AUTO:
         return
 
-    elif model_args.flash_attn == "disabled":
+    elif model_args.flash_attn == AttentionFunction.DISABLED:
         requested_attn_implementation = "eager"
 
-    elif model_args.flash_attn == "sdpa":
+    elif model_args.flash_attn == AttentionFunction.SDPA:
         if not is_torch_sdpa_available():
             logger.warning_rank0("torch>=2.1.1 is required for SDPA attention.")
             return
 
         requested_attn_implementation = "sdpa"
-    elif model_args.flash_attn == "fa2":
+    elif model_args.flash_attn == AttentionFunction.FA2:
         if not is_flash_attn_2_available():
             logger.warning_rank0("FlashAttention-2 is not installed.")
             return
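The AttentionFunction constant imported above is defined in ...extras.constants and is not shown in this diff. Below is a minimal sketch of what such a definition plausibly looks like, assuming a string-valued enum whose member values match the old literals ("auto", "disabled", "sdpa", "fa2") so the equality checks in the updated code keep accepting plain strings; the repository's actual definition may differ.

# Hypothetical sketch of the imported enum; not the repository's actual code.
from enum import Enum, unique


@unique
class AttentionFunction(str, Enum):
    AUTO = "auto"          # let the framework pick an implementation
    DISABLED = "disabled"  # fall back to eager attention
    SDPA = "sdpa"          # torch scaled_dot_product_attention (torch>=2.1.1)
    FA2 = "fa2"            # FlashAttention-2


# With the str mix-in, AttentionFunction.FA2 == "fa2" evaluates to True,
# so configs that still pass raw strings for `flash_attn` remain valid.

Using a single enum for the accepted values keeps the comparisons in configure_attn_implementation consistent and avoids scattering raw string literals through the argument-handling code.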
||||