[v1] add accelerator (#9607)
@@ -15,6 +15,8 @@
 from dataclasses import dataclass, field
 
+from .arg_utils import AutoClass, PluginConfig, get_plugin_config
+
 
 @dataclass
 class ModelArguments:
@@ -29,7 +31,24 @@ class ModelArguments:
         default=True,
         metadata={"help": "Use fast processor from Hugging Face."},
     )
-    auto_model_class: str = field(
-        default="causallm",
+    auto_class: AutoClass = field(
+        default=AutoClass.CAUSALLM,
         metadata={"help": "Model class from Hugging Face."},
     )
+    peft_config: PluginConfig = field(
+        default=None,
+        metadata={"help": "PEFT configuration for the model."},
+    )
+    kernel_config: PluginConfig = field(
+        default=None,
+        metadata={"help": "Kernel configuration for the model."},
+    )
+    quant_config: PluginConfig = field(
+        default=None,
+        metadata={"help": "Quantization configuration for the model."},
+    )
+
+    def __post_init__(self) -> None:
+        self.peft_config = get_plugin_config(self.peft_config)
+        self.kernel_config = get_plugin_config(self.kernel_config)
+        self.quant_config = get_plugin_config(self.quant_config)
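Note on the new plugin fields: peft_config, kernel_config, and quant_config are declared as PluginConfig but default to None, and __post_init__ routes whatever was parsed (typically a dict from a YAML or CLI config) through get_plugin_config. This commit does not include arg_utils, so the sketch below only illustrates one plausible shape for those helpers; the PluginConfig fields (name, kwargs) and the normalization rules shown here are assumptions, not the repository's actual implementation.

# Hypothetical sketch of the .arg_utils helpers referenced in the diff above.
# Only the names AutoClass, AutoClass.CAUSALLM, PluginConfig, and get_plugin_config
# come from the commit; everything else is assumed for illustration.
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Optional, Union


class AutoClass(str, Enum):
    """Auto model classes; only the member visible in the diff is listed."""

    CAUSALLM = "causallm"


@dataclass
class PluginConfig:
    """Assumed shape: a plugin name plus free-form keyword arguments."""

    name: str
    kwargs: dict[str, Any] = field(default_factory=dict)


def get_plugin_config(
    config: Union[None, dict[str, Any], PluginConfig],
) -> Optional[PluginConfig]:
    """Normalize a raw plugin spec (None, dict, or PluginConfig) into a PluginConfig."""
    if config is None or isinstance(config, PluginConfig):
        return config  # unset, or already normalized: pass through unchanged
    if isinstance(config, dict):
        spec = dict(config)  # copy so the caller's dict is not mutated
        return PluginConfig(name=spec.pop("name"), kwargs=spec)
    raise TypeError(f"Unsupported plugin config type: {type(config)!r}")

Under these assumptions, constructing ModelArguments(peft_config={"name": "lora", "r": 16}) would leave a PluginConfig instance on the dataclass after __post_init__ runs, while a field left as None would stay None.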