Mirror of https://github.com/hiyouga/LLaMA-Factory.git, synced 2025-12-27 17:20:35 +08:00
[v1] model loader (#9613)
@@ -12,16 +12,18 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-
+import os
 from dataclasses import dataclass, field
+from typing import Optional
+from uuid import uuid4
 
-from .arg_utils import PluginArgument, get_plugin_config
+from .arg_utils import PluginConfig, get_plugin_config
 
 
 @dataclass
 class TrainingArguments:
     output_dir: str = field(
-        default="",
+        default=os.path.join("outputs", str(uuid4())),
         metadata={"help": "Path to the output directory."},
     )
     micro_batch_size: int = field(
@@ -40,7 +42,7 @@ class TrainingArguments:
         default=False,
         metadata={"help": "Use bf16 for training."},
     )
-    dist_config: PluginArgument = field(
+    dist_config: Optional[PluginConfig] = field(
         default=None,
         metadata={"help": "Distribution configuration for training."},
    )
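Net effect of the change: output_dir now defaults to a fresh directory under outputs/ instead of an empty string, and dist_config becomes an optional plugin config. Below is a minimal standalone sketch of the same pattern, not taken from the commit; ExampleArguments and the plain dict standing in for the repo's PluginConfig type are illustrative assumptions.

import os
from dataclasses import dataclass, field
from typing import Optional
from uuid import uuid4


@dataclass
class ExampleArguments:  # illustrative stand-in for TrainingArguments, not the repo class
    # The default expression runs once, when the class body is executed,
    # so a process gets a single fresh outputs/<uuid> directory as its default.
    output_dir: str = field(
        default=os.path.join("outputs", str(uuid4())),
        metadata={"help": "Path to the output directory."},
    )
    # Optional plugin config; a dict stands in for the repo's PluginConfig type here.
    dist_config: Optional[dict] = field(
        default=None,
        metadata={"help": "Distribution configuration for training."},
    )


if __name__ == "__main__":
    args = ExampleArguments()
    print(args.output_dir)   # e.g. outputs/8f1c2d3e-... (new value for each process)
    print(args.dist_config)  # None unless a distribution config is supplied

Because the UUID is evaluated at class-definition time rather than per instance, every instance created in one process shares the same default path; each new training run (new process) gets a distinct one, which avoids silently writing into an unset or reused output directory.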