mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-12-27 17:20:35 +08:00
[v1] add accelerator (#9607)
This commit is contained in:
@@ -15,6 +15,8 @@
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from .arg_utils import PluginArgument, get_plugin_config
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrainingArguments:
|
||||
@@ -38,3 +40,10 @@ class TrainingArguments:
|
||||
default=False,
|
||||
metadata={"help": "Use bf16 for training."},
|
||||
)
|
||||
dist_config: PluginArgument = field(
|
||||
default=None,
|
||||
metadata={"help": "Distribution configuration for training."},
|
||||
)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
self.dist_config = get_plugin_config(self.dist_config)
|
||||
|
||||
Reference in New Issue
Block a user