Mirror of https://github.com/hiyouga/LLaMA-Factory.git
[v1] add sft (#9752)
@@ -30,9 +30,9 @@ from .training_args import TrainingArguments
 InputArgument = dict[str, Any] | list[str] | None
 
 
-def get_args(args: InputArgument = None) -> tuple[DataArguments, ModelArguments, TrainingArguments, SampleArguments]:
+def get_args(args: InputArgument = None) -> tuple[ModelArguments, DataArguments, TrainingArguments, SampleArguments]:
     """Parse arguments from command line or config file."""
-    parser = HfArgumentParser([DataArguments, ModelArguments, TrainingArguments, SampleArguments])
+    parser = HfArgumentParser([ModelArguments, DataArguments, TrainingArguments, SampleArguments])
     allow_extra_keys = is_env_enabled("ALLOW_EXTRA_KEYS")
 
     if args is None:
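HfArgumentParser returns the parsed dataclass instances in the same order the classes are registered, which is why this hunk swaps the order in both the parser list and the return annotation together. A minimal sketch of that contract, assuming a recent transformers version; the two toy dataclasses below are illustrative stand-ins, not the repo's full argument classes:

# Sketch: HfArgumentParser yields one instance per registered dataclass,
# in registration order, so the unpacking order must match the list order.
from dataclasses import dataclass, field

from transformers import HfArgumentParser


@dataclass
class ModelArguments:
    model: str | None = field(default=None, metadata={"help": "Model name or path."})


@dataclass
class DataArguments:
    train_dataset: str | None = field(default=None, metadata={"help": "Path to the training dataset."})


parser = HfArgumentParser([ModelArguments, DataArguments])
model_args, data_args = parser.parse_args_into_dataclasses(
    args=["--model", "llama3", "--train_dataset", "alpaca_en"]
)
print(model_args.model, data_args.train_dataset)  # llama3 alpaca_en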
@@ -18,7 +18,11 @@ from dataclasses import dataclass, field
 
 @dataclass
 class DataArguments:
-    dataset: str | None = field(
+    train_dataset: str | None = field(
         default=None,
-        metadata={"help": "Path to the dataset."},
+        metadata={"help": "Path to the training dataset."},
     )
+    eval_dataset: str | None = field(
+        default=None,
+        metadata={"help": "Path to the evaluation dataset."},
+    )
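Splitting the single dataset field into train_dataset and eval_dataset means configs now name the two splits explicitly. A hedged sketch of how such a config dict could be parsed into the new fields, using the standard HfArgumentParser.parse_dict API; the field names come from this hunk, the file paths are invented:

# Sketch: parse a plain config dict into the updated DataArguments.
from dataclasses import dataclass, field

from transformers import HfArgumentParser


@dataclass
class DataArguments:
    train_dataset: str | None = field(default=None, metadata={"help": "Path to the training dataset."})
    eval_dataset: str | None = field(default=None, metadata={"help": "Path to the evaluation dataset."})


parser = HfArgumentParser([DataArguments])
(data_args,) = parser.parse_dict(
    {"train_dataset": "data/train.jsonl", "eval_dataset": "data/eval.jsonl"},
    allow_extra_keys=False,
)
print(data_args.eval_dataset)  # data/eval.jsonl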
@@ -33,13 +33,21 @@ class TrainingArguments:
         default=None,
         metadata={"help": "Global batch size for training, default to DP size * micro batch size."},
     )
+    cutoff_len: int = field(
+        default=2048,
+        metadata={"help": "Maximum sequence length for training."},
+    )
     learning_rate: float = field(
         default=1e-4,
         metadata={"help": "Learning rate for training."},
     )
-    cutoff_len: int = field(
-        default=2048,
-        metadata={"help": "Maximum sequence length for training."},
+    num_train_epochs: int = field(
+        default=3,
+        metadata={"help": "Number of training epochs."},
     )
+    max_grad_norm: float = field(
+        default=1.0,
+        metadata={"help": "Maximum gradient norm for training."},
+    )
     bf16: bool = field(
         default=False,
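The global_batch_size help text above encodes a small formula: when the field is left unset, the global batch size presumably falls back to the data-parallel world size times the micro batch size. A toy illustration of that arithmetic (variable names invented, not the repo's):

# global batch size = DP size * micro batch size
dp_size = 4            # number of data-parallel ranks
micro_batch_size = 8   # per-rank batch size per step
global_batch_size = dp_size * micro_batch_size
print(global_batch_size)  # 32 samples per optimizer step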
@@ -53,10 +61,24 @@ class TrainingArguments:
         default=16,
         metadata={"help": "Number of workers for batching."},
     )
     enable_activation_checkpointing: bool = field(
         default=True,
         metadata={"help": "Enable activation checkpointing for training."},
     )
+    dist_config: PluginConfig | None = field(
+        default=None,
+        metadata={"help": "Distribution configuration for training."},
+    )
+    optim_config: PluginConfig | None = field(
+        default=None,
+        metadata={"help": "Optimizer configuration for training."},
+    )
+    lr_scheduler_config: PluginConfig | None = field(
+        default=None,
+        metadata={"help": "Learning rate scheduler configuration for training."},
+    )
 
     def __post_init__(self) -> None:
+        self.dist_config = get_plugin_config(self.dist_config)
+        self.optim_config = get_plugin_config(self.optim_config)
+        self.lr_scheduler_config = get_plugin_config(self.lr_scheduler_config)
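The __post_init__ hook normalizes the three plugin fields, so callers can pass either a typed PluginConfig or a plain mapping from a YAML/JSON config. Since PluginConfig and get_plugin_config are defined elsewhere in the repo and not shown in this diff, the sketch below uses simplified stand-ins to show the shape of the pattern, not the real implementations:

# Sketch of the __post_init__ normalization pattern; PluginConfig and
# get_plugin_config here are simplified stand-ins for the real ones.
from dataclasses import dataclass, field
from typing import Any


@dataclass
class PluginConfig:
    name: str
    kwargs: dict[str, Any] = field(default_factory=dict)


def get_plugin_config(config: PluginConfig | dict[str, Any] | None) -> PluginConfig | None:
    """Pass typed configs through, promote raw dicts, leave None alone."""
    if config is None or isinstance(config, PluginConfig):
        return config
    return PluginConfig(name=config["name"], kwargs=config.get("kwargs", {}))


@dataclass
class TrainingArguments:
    dist_config: PluginConfig | dict[str, Any] | None = field(default=None)

    def __post_init__(self) -> None:
        # Accept a PluginConfig or a plain dict; always end up typed.
        self.dist_config = get_plugin_config(self.dist_config)


args = TrainingArguments(dist_config={"name": "fsdp", "kwargs": {"sharding": "full"}})
print(args.dist_config)  # PluginConfig(name='fsdp', kwargs={'sharding': 'full'})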