Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-12-17 12:20:37 +08:00)
[config] update args (#7231)
@@ -23,6 +23,8 @@ import torch
 from transformers.training_args import _convert_str_dict
 from typing_extensions import Self
 
+from ..extras.constants import AttentionFunction, EngineName, RopeScaling
+
 
 @dataclass
 class BaseModelArguments:
@@ -77,12 +79,12 @@ class BaseModelArguments:
         default=True,
         metadata={"help": "Whether or not to use memory-efficient model loading."},
     )
-    rope_scaling: Optional[Literal["linear", "dynamic", "yarn", "llama3"]] = field(
+    rope_scaling: Optional[RopeScaling] = field(
         default=None,
         metadata={"help": "Which scaling strategy should be adopted for the RoPE embeddings."},
     )
-    flash_attn: Literal["auto", "disabled", "sdpa", "fa2"] = field(
-        default="auto",
+    flash_attn: AttentionFunction = field(
+        default=AttentionFunction.AUTO,
         metadata={"help": "Enable FlashAttention for faster training and inference."},
     )
     shift_attn: bool = field(
@@ -129,8 +131,8 @@ class BaseModelArguments:
         default=False,
         metadata={"help": "Whether or not to randomly initialize the model weights."},
     )
-    infer_backend: Literal["huggingface", "vllm"] = field(
-        default="huggingface",
+    infer_backend: EngineName = field(
+        default=EngineName.HF,
         metadata={"help": "Backend engine used at inference."},
     )
     offload_folder: str = field(
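
For context, the Literal choices removed above are replaced by enum types imported from ..extras.constants. Below is a minimal sketch of how such string-backed enums could look; the member values come from the removed Literal lists and the new defaults in this diff, but the exact class bodies in extras/constants.py are an assumption, not a verbatim copy of that file:

# Hedged sketch of the enums referenced by the new annotations.
# Member values mirror the old Literal choices shown in the diff; everything else is assumed.
from enum import Enum, unique


@unique
class AttentionFunction(str, Enum):
    AUTO = "auto"
    DISABLED = "disabled"
    SDPA = "sdpa"
    FA2 = "fa2"


@unique
class EngineName(str, Enum):
    HF = "huggingface"
    VLLM = "vllm"


@unique
class RopeScaling(str, Enum):
    LINEAR = "linear"
    DYNAMIC = "dynamic"
    YARN = "yarn"
    LLAMA3 = "llama3"

Subclassing str keeps the new defaults interchangeable with plain strings, so AttentionFunction.AUTO still compares equal to "auto" and existing YAML/CLI configs that pass string values should continue to parse into the same fields.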