mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2026-01-13 09:30:34 +08:00
[v1] add init plugin (#9716)
This commit is contained in:
@@ -0,0 +1,32 @@
|
||||
# Copyright 2025 the LlamaFactory team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .arg_parser import InputArgument, get_args
|
||||
from .arg_utils import ModelClass, SampleBackend
|
||||
from .data_args import DataArguments
|
||||
from .model_args import ModelArguments
|
||||
from .sample_args import SampleArguments
|
||||
from .training_args import TrainingArguments
|
||||
|
||||
|
||||
__all__ = [
|
||||
"DataArguments",
|
||||
"InputArgument",
|
||||
"ModelArguments",
|
||||
"ModelClass",
|
||||
"SampleArguments",
|
||||
"SampleBackend",
|
||||
"TrainingArguments",
|
||||
"get_args",
|
||||
]
|
||||
|
||||
@@ -27,14 +27,14 @@ class ModelArguments:
|
||||
default=False,
|
||||
metadata={"help": "Trust remote code from Hugging Face."},
|
||||
)
|
||||
use_fast_processor: bool = field(
|
||||
default=True,
|
||||
metadata={"help": "Use fast processor from Hugging Face."},
|
||||
)
|
||||
model_class: ModelClass = field(
|
||||
default=ModelClass.LLM,
|
||||
metadata={"help": "Model class from Hugging Face."},
|
||||
)
|
||||
init_config: PluginConfig | None = field(
|
||||
default=None,
|
||||
metadata={"help": "Initialization configuration for the model."},
|
||||
)
|
||||
peft_config: PluginConfig | None = field(
|
||||
default=None,
|
||||
metadata={"help": "PEFT configuration for the model."},
|
||||
@@ -49,6 +49,7 @@ class ModelArguments:
|
||||
)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
self.init_config = get_plugin_config(self.init_config)
|
||||
self.peft_config = get_plugin_config(self.peft_config)
|
||||
self.kernel_config = get_plugin_config(self.kernel_config)
|
||||
self.quant_config = get_plugin_config(self.quant_config)
|
||||
|
||||
@@ -22,7 +22,7 @@ from .arg_utils import PluginConfig, get_plugin_config
|
||||
@dataclass
|
||||
class TrainingArguments:
|
||||
output_dir: str = field(
|
||||
default=os.path.join("outputs", str(uuid4())),
|
||||
default=os.path.join("outputs", str(uuid4().hex)),
|
||||
metadata={"help": "Path to the output directory."},
|
||||
)
|
||||
micro_batch_size: int = field(
|
||||
|
||||
Reference in New Issue
Block a user