Mirror of https://github.com/hiyouga/LLaMA-Factory.git, synced 2025-08-23 06:12:50 +08:00
Bug Fix: `off` is parsed as `False` in the YAML file, changed to `disabled` to avoid this.
Former-commit-id: 3ed063f281d1c2563df1b9eb3800543208c9dc16
This commit is contained in:
parent bb9f48590f
commit 9aa640f27b
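The root cause is a YAML 1.1 quirk rather than anything specific to this repository: bare `off`/`on`/`yes`/`no` scalars are resolved as booleans, so a user who writes `flash_attn: off` gets `False` instead of the string `"off"`. A minimal sketch of the pitfall, assuming the config file is read with PyYAML's `safe_load` (as the training-config loading here appears to do):

import yaml

# YAML 1.1 boolean resolution: a bare `off` becomes the Python bool False.
print(yaml.safe_load("flash_attn: off"))       # {'flash_attn': False}

# The renamed value survives as a plain string, so the Literal choices can match it.
print(yaml.safe_load("flash_attn: disabled"))  # {'flash_attn': 'disabled'}

# Quoting the value would also avoid the coercion, but users rarely think to do it,
# which is why renaming the option is the safer fix.
print(yaml.safe_load('flash_attn: "off"'))     # {'flash_attn': 'off'}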
@@ -97,7 +97,7 @@ class ModelArguments:
         default=None,
         metadata={"help": "Which scaling strategy should be adopted for the RoPE embeddings."},
     )
-    flash_attn: Literal["off", "sdpa", "fa2", "auto"] = field(
+    flash_attn: Literal["disabled", "sdpa", "fa2", "auto"] = field(
         default="auto",
         metadata={"help": "Enable FlashAttention for faster training and inference."},
     )

@@ -102,6 +102,10 @@ def _verify_model_args(model_args: "ModelArguments", finetuning_args: "Finetunin
         if model_args.adapter_name_or_path is not None and len(model_args.adapter_name_or_path) != 1:
             raise ValueError("Quantized model only accepts a single adapter. Merge them first.")
 
+    # In case that `flash_attn` is set to `off` in the yaml file, and parsed as `False` afterwards.
+    if model_args.flash_attn == False:
+        raise ValueError("flash_attn should be \"disabled\", \"sdpa\", \"fa2\" or \"auto\".")
+
 
 def _check_extra_dependencies(
     model_args: "ModelArguments",

@@ -32,7 +32,7 @@ def configure_attn_implementation(config: "PretrainedConfig", model_args: "Model
     if model_args.flash_attn == "auto":
         return
 
-    elif model_args.flash_attn == "off":
+    elif model_args.flash_attn == "disabled":
         requested_attn_implementation = "eager"
 
     elif model_args.flash_attn == "sdpa":
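Taken together, the three hunks give the behavior sketched below. `FakeModelArgs`, `verify`, and `choose_impl` are hypothetical stand-ins written only to trace the control flow of the patched `_verify_model_args` and `configure_attn_implementation`, not the project's actual API:

from dataclasses import dataclass
from typing import Literal, Union

@dataclass
class FakeModelArgs:
    # A YAML `off` reaches the arguments as the bool False, which is why the
    # explicit guard below is still needed after the Literal was renamed.
    flash_attn: Union[bool, Literal["disabled", "sdpa", "fa2", "auto"]] = "auto"

def verify(args: FakeModelArgs) -> None:
    # Mirrors the new guard added to _verify_model_args.
    if args.flash_attn is False:
        raise ValueError('flash_attn should be "disabled", "sdpa", "fa2" or "auto".')

def choose_impl(args: FakeModelArgs) -> str:
    # Mirrors configure_attn_implementation: "disabled" now selects eager attention.
    if args.flash_attn == "auto":
        return "auto"
    elif args.flash_attn == "disabled":
        return "eager"
    elif args.flash_attn == "sdpa":
        return "sdpa"
    return "flash_attention_2"

print(choose_impl(FakeModelArgs(flash_attn="disabled")))  # eager

try:
    verify(FakeModelArgs(flash_attn=False))  # what an old `flash_attn: off` config turns into
except ValueError as err:
    print(err)  # flash_attn should be "disabled", "sdpa", "fa2" or "auto".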