[v1] add sft (#9752)

This commit is contained in:
Yaowei Zheng
2026-01-12 03:15:01 +08:00
committed by GitHub
parent 4d3621e3d3
commit 958b9c3468
29 changed files with 439 additions and 305 deletions

View File

@@ -33,13 +33,21 @@ class TrainingArguments:
default=None,
metadata={"help": "Global batch size for training, default to DP size * micro batch size."},
)
cutoff_len: int = field(
default=2048,
metadata={"help": "Maximum sequence length for training."},
)
learning_rate: float = field(
default=1e-4,
metadata={"help": "Learning rate for training."},
)
cutoff_len: int = field(
default=2048,
metadata={"help": "Maximum sequence length for training."},
)
num_train_epochs: int = field(
default=3,
metadata={"help": "Number of training epochs."},
)
max_grad_norm: float = field(
default=1.0,
metadata={"help": "Maximum gradient norm for training."},
)
bf16: bool = field(
default=False,
@@ -53,10 +61,24 @@ class TrainingArguments:
default=16,
metadata={"help": "Number of workers for batching."},
)
enable_activation_checkpointing: bool = field(
default=True,
metadata={"help": "Enable activation checkpointing for training."},
)
dist_config: PluginConfig | None = field(
default=None,
metadata={"help": "Distribution configuration for training."},
)
optim_config: PluginConfig | None = field(
default=None,
metadata={"help": "Optimizer configuration for training."},
)
lr_scheduler_config: PluginConfig | None = field(
default=None,
metadata={"help": "Learning rate scheduler configuration for training."},
)
def __post_init__(self) -> None:
    """Resolve each raw plugin-config attribute into a concrete plugin config.

    Runs after dataclass initialization; every listed attribute is passed
    through ``get_plugin_config`` and the result is written back in place.
    """
    for attr_name in ("dist_config", "optim_config", "lr_scheduler_config"):
        setattr(self, attr_name, get_plugin_config(getattr(self, attr_name)))