[v1] model loader (#9613)

This commit is contained in:
Yaowei Zheng
2025-12-14 11:50:52 +08:00
committed by GitHub
parent fdd24276ed
commit aeda079014
27 changed files with 449 additions and 305 deletions

View File

@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import sys
from pathlib import Path
@@ -28,6 +27,9 @@ from .sample_args import SampleArguments
from .training_args import TrainingArguments
InputArgument = Optional[Union[dict[str, Any], list[str]]]
def validate_args(
data_args: DataArguments,
model_args: ModelArguments,
@@ -43,9 +45,7 @@ def validate_args(
raise ValueError("Quantization is not supported with deepspeed backend.")
def get_args(
args: Optional[Union[dict[str, Any], list[str]]] = None,
) -> tuple[DataArguments, ModelArguments, TrainingArguments, SampleArguments]:
def get_args(args: InputArgument = None) -> tuple[DataArguments, ModelArguments, TrainingArguments, SampleArguments]:
"""Parse arguments from command line or config file."""
parser = HfArgumentParser([DataArguments, ModelArguments, TrainingArguments, SampleArguments])
allow_extra_keys = is_env_enabled("ALLOW_EXTRA_KEYS")

View File

@@ -18,7 +18,7 @@
import json
from enum import Enum, unique
from typing import Any, Optional, Union
from typing import Optional, Union
class PluginConfig(dict):
@@ -32,25 +32,16 @@ class PluginConfig(dict):
return self["name"]
def __getattr__(self, key: str) -> Any:
try:
return self[key]
except KeyError:
raise AttributeError(f"Attribute {key} not found.")
def __setattr__(self, key: str, value: Any):
self[key] = value
PluginArgument = Optional[Union[PluginConfig, dict, str]]
@unique
class AutoClass(str, Enum):
class ModelClass(str, Enum):
"""Auto class for model config."""
CAUSALLM = "llm"
CLASSIFICATION = "cls"
LLM = "llm"
CLS = "cls"
OTHER = "other"

View File

@@ -14,8 +14,9 @@
from dataclasses import dataclass, field
from typing import Optional
from .arg_utils import AutoClass, PluginConfig, get_plugin_config
from .arg_utils import ModelClass, PluginConfig, get_plugin_config
@dataclass
@@ -31,19 +32,19 @@ class ModelArguments:
default=True,
metadata={"help": "Use fast processor from Hugging Face."},
)
auto_class: AutoClass = field(
default=AutoClass.CAUSALLM,
model_class: ModelClass = field(
default=ModelClass.LLM,
metadata={"help": "Model class from Hugging Face."},
)
peft_config: PluginConfig = field(
peft_config: Optional[PluginConfig] = field(
default=None,
metadata={"help": "PEFT configuration for the model."},
)
kernel_config: PluginConfig = field(
kernel_config: Optional[PluginConfig] = field(
default=None,
metadata={"help": "Kernel configuration for the model."},
)
quant_config: PluginConfig = field(
quant_config: Optional[PluginConfig] = field(
default=None,
metadata={"help": "Quantization configuration for the model."},
)

View File

@@ -12,16 +12,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from dataclasses import dataclass, field
from typing import Optional
from uuid import uuid4
from .arg_utils import PluginArgument, get_plugin_config
from .arg_utils import PluginConfig, get_plugin_config
@dataclass
class TrainingArguments:
output_dir: str = field(
default="",
default=os.path.join("outputs", str(uuid4())),
metadata={"help": "Path to the output directory."},
)
micro_batch_size: int = field(
@@ -40,7 +42,7 @@ class TrainingArguments:
default=False,
metadata={"help": "Use bf16 for training."},
)
dist_config: PluginArgument = field(
dist_config: Optional[PluginConfig] = field(
default=None,
metadata={"help": "Distribution configuration for training."},
)