[v1] add accelerator (#9607)

This commit is contained in:
Yaowei Zheng
2025-12-12 19:22:06 +08:00
committed by GitHub
parent 4fd94141a4
commit 203069e11c
36 changed files with 941 additions and 443 deletions

View File

@@ -15,6 +15,8 @@
from dataclasses import dataclass, field
from .arg_utils import AutoClass, PluginConfig, get_plugin_config
@dataclass
class ModelArguments:
@@ -29,7 +31,24 @@ class ModelArguments:
default=True,
metadata={"help": "Use fast processor from Hugging Face."},
)
auto_model_class: str = field(
default="causallm",
auto_class: AutoClass = field(
default=AutoClass.CAUSALLM,
metadata={"help": "Model class from Hugging Face."},
)
peft_config: PluginConfig = field(
default=None,
metadata={"help": "PEFT configuration for the model."},
)
kernel_config: PluginConfig = field(
default=None,
metadata={"help": "Kernel configuration for the model."},
)
quant_config: PluginConfig = field(
default=None,
metadata={"help": "Quantization configuration for the model."},
)
def __post_init__(self) -> None:
self.peft_config = get_plugin_config(self.peft_config)
self.kernel_config = get_plugin_config(self.kernel_config)
self.quant_config = get_plugin_config(self.quant_config)