mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2026-01-13 09:30:34 +08:00
[v1] model loader (#9613)
This commit is contained in:
@@ -12,7 +12,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import json
|
||||
import sys
|
||||
from pathlib import Path
|
||||
@@ -28,6 +27,9 @@ from .sample_args import SampleArguments
|
||||
from .training_args import TrainingArguments
|
||||
|
||||
|
||||
InputArgument = Optional[Union[dict[str, Any], list[str]]]
|
||||
|
||||
|
||||
def validate_args(
|
||||
data_args: DataArguments,
|
||||
model_args: ModelArguments,
|
||||
@@ -43,9 +45,7 @@ def validate_args(
|
||||
raise ValueError("Quantization is not supported with deepspeed backend.")
|
||||
|
||||
|
||||
def get_args(
|
||||
args: Optional[Union[dict[str, Any], list[str]]] = None,
|
||||
) -> tuple[DataArguments, ModelArguments, TrainingArguments, SampleArguments]:
|
||||
def get_args(args: InputArgument = None) -> tuple[DataArguments, ModelArguments, TrainingArguments, SampleArguments]:
|
||||
"""Parse arguments from command line or config file."""
|
||||
parser = HfArgumentParser([DataArguments, ModelArguments, TrainingArguments, SampleArguments])
|
||||
allow_extra_keys = is_env_enabled("ALLOW_EXTRA_KEYS")
|
||||
|
||||
@@ -18,7 +18,7 @@
|
||||
|
||||
import json
|
||||
from enum import Enum, unique
|
||||
from typing import Any, Optional, Union
|
||||
from typing import Optional, Union
|
||||
|
||||
|
||||
class PluginConfig(dict):
|
||||
@@ -32,25 +32,16 @@ class PluginConfig(dict):
|
||||
|
||||
return self["name"]
|
||||
|
||||
def __getattr__(self, key: str) -> Any:
|
||||
try:
|
||||
return self[key]
|
||||
except KeyError:
|
||||
raise AttributeError(f"Attribute {key} not found.")
|
||||
|
||||
def __setattr__(self, key: str, value: Any):
|
||||
self[key] = value
|
||||
|
||||
|
||||
PluginArgument = Optional[Union[PluginConfig, dict, str]]
|
||||
|
||||
|
||||
@unique
|
||||
class AutoClass(str, Enum):
|
||||
class ModelClass(str, Enum):
|
||||
"""Auto class for model config."""
|
||||
|
||||
CAUSALLM = "llm"
|
||||
CLASSIFICATION = "cls"
|
||||
LLM = "llm"
|
||||
CLS = "cls"
|
||||
OTHER = "other"
|
||||
|
||||
|
||||
|
||||
@@ -14,8 +14,9 @@
|
||||
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
from .arg_utils import AutoClass, PluginConfig, get_plugin_config
|
||||
from .arg_utils import ModelClass, PluginConfig, get_plugin_config
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -31,19 +32,19 @@ class ModelArguments:
|
||||
default=True,
|
||||
metadata={"help": "Use fast processor from Hugging Face."},
|
||||
)
|
||||
auto_class: AutoClass = field(
|
||||
default=AutoClass.CAUSALLM,
|
||||
model_class: ModelClass = field(
|
||||
default=ModelClass.LLM,
|
||||
metadata={"help": "Model class from Hugging Face."},
|
||||
)
|
||||
peft_config: PluginConfig = field(
|
||||
peft_config: Optional[PluginConfig] = field(
|
||||
default=None,
|
||||
metadata={"help": "PEFT configuration for the model."},
|
||||
)
|
||||
kernel_config: PluginConfig = field(
|
||||
kernel_config: Optional[PluginConfig] = field(
|
||||
default=None,
|
||||
metadata={"help": "Kernel configuration for the model."},
|
||||
)
|
||||
quant_config: PluginConfig = field(
|
||||
quant_config: Optional[PluginConfig] = field(
|
||||
default=None,
|
||||
metadata={"help": "Quantization configuration for the model."},
|
||||
)
|
||||
|
||||
@@ -12,16 +12,18 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
from uuid import uuid4
|
||||
|
||||
from .arg_utils import PluginArgument, get_plugin_config
|
||||
from .arg_utils import PluginConfig, get_plugin_config
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrainingArguments:
|
||||
output_dir: str = field(
|
||||
default="",
|
||||
default=os.path.join("outputs", str(uuid4())),
|
||||
metadata={"help": "Path to the output directory."},
|
||||
)
|
||||
micro_batch_size: int = field(
|
||||
@@ -40,7 +42,7 @@ class TrainingArguments:
|
||||
default=False,
|
||||
metadata={"help": "Use bf16 for training."},
|
||||
)
|
||||
dist_config: PluginArgument = field(
|
||||
dist_config: Optional[PluginConfig] = field(
|
||||
default=None,
|
||||
metadata={"help": "Distribution configuration for training."},
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user