mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2026-03-04 02:35:59 +08:00
[v1] add seed for training and fix gradient checkpointing (#10211)
This commit is contained in:
@@ -21,6 +21,7 @@ from omegaconf import OmegaConf
|
||||
from transformers import HfArgumentParser
|
||||
|
||||
from ..utils.env import is_env_enabled
|
||||
from ..utils.helper import set_seed
|
||||
from .data_args import DataArguments
|
||||
from .model_args import ModelArguments
|
||||
from .sample_args import SampleArguments
|
||||
@@ -56,6 +57,14 @@ def get_args(args: InputArgument = None) -> tuple[ModelArguments, DataArguments,
|
||||
print(f"Got unknown args, potentially deprecated arguments: {unknown_args}")
|
||||
raise ValueError(f"Some specified arguments are not used by the HfArgumentParser: {unknown_args}")
|
||||
|
||||
# Seed as early as possible after argument parsing so all downstream
|
||||
# components (dist init, dataloader, model init in run_* entrypoints) share the same RNG state.
|
||||
for arg in parsed_args:
|
||||
seed = getattr(arg, "seed", None)
|
||||
if seed is not None:
|
||||
set_seed(seed)
|
||||
break
|
||||
|
||||
return tuple(parsed_args)
|
||||
|
||||
|
||||
|
||||
@@ -66,7 +66,7 @@ class TrainingArguments:
|
||||
metadata={"help": "Number of workers for batching."},
|
||||
)
|
||||
enable_activation_checkpointing: bool = field(
|
||||
default=True,
|
||||
default=False,
|
||||
metadata={"help": "Enable activation checkpointing for training."},
|
||||
)
|
||||
dist_config: PluginConfig | None = field(
|
||||
@@ -81,6 +81,10 @@ class TrainingArguments:
|
||||
default=None,
|
||||
metadata={"help": "Learning rate scheduler configuration for training."},
|
||||
)
|
||||
seed: int = field(
|
||||
default=42,
|
||||
metadata={"help": "Random seed that will be set at the beginning of training."},
|
||||
)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
self.dist_config = get_plugin_config(self.dist_config)
|
||||
|
||||
Reference in New Issue
Block a user