From 45d335c709b677aef205afd62db78ab2c2bcde9a Mon Sep 17 00:00:00 2001
From: jiaqiw09 <60021713+jiaqiw09@users.noreply.github.com>
Date: Sat, 28 Feb 2026 18:16:06 +0800
Subject: [PATCH] [v1] add seed for training and fix gradient checkpointing (#10211)

---
 examples/v1/train_full/train_full_fsdp2.yaml        |  4 ----
 src/llamafactory/v1/config/arg_parser.py            |  9 +++++++++
 src/llamafactory/v1/config/training_args.py         |  6 +++++-
 src/llamafactory/v1/core/base_trainer.py            |  3 ++-
 src/llamafactory/v1/core/utils/batching.py          |  9 ++++++++-
 .../v1/plugins/trainer_plugins/distributed/fsdp2.py |  9 ++++-----
 src/llamafactory/v1/utils/helper.py                 | 10 ++++++++++
 7 files changed, 38 insertions(+), 12 deletions(-)

diff --git a/examples/v1/train_full/train_full_fsdp2.yaml b/examples/v1/train_full/train_full_fsdp2.yaml
index dfad62022..57ac6a1f3 100644
--- a/examples/v1/train_full/train_full_fsdp2.yaml
+++ b/examples/v1/train_full/train_full_fsdp2.yaml
@@ -14,16 +14,12 @@ dist_config:
   name: fsdp2
   dcp_path: null # /mnt/f/pretrain_models/Qwen3-0.6B-dcp
 
-init_config:
-  name: init_on_meta
-
 ### data
 train_dataset: data/v1_sft_demo.yaml
 
 ### training
 output_dir: outputs/test_fsdp2
 micro_batch_size: 1
-global_batch_size: 1
 cutoff_len: 2048
 learning_rate: 1.0e-4
 bf16: false
diff --git a/src/llamafactory/v1/config/arg_parser.py b/src/llamafactory/v1/config/arg_parser.py
index 2122a569f..0a0caddd2 100644
--- a/src/llamafactory/v1/config/arg_parser.py
+++ b/src/llamafactory/v1/config/arg_parser.py
@@ -21,6 +21,7 @@ from omegaconf import OmegaConf
 from transformers import HfArgumentParser
 
 from ..utils.env import is_env_enabled
+from ..utils.helper import set_seed
 from .data_args import DataArguments
 from .model_args import ModelArguments
 from .sample_args import SampleArguments
@@ -56,6 +57,14 @@ def get_args(args: InputArgument = None) -> tuple[ModelArguments, DataArguments,
         print(f"Got unknown args, potentially deprecated arguments: {unknown_args}")
         raise ValueError(f"Some specified arguments are not used by the HfArgumentParser: {unknown_args}")
 
+    # Seed as early as possible after argument parsing so all downstream
+    # components (dist init, dataloader, model init in run_* entrypoints) share the same RNG state.
+    for arg in parsed_args:
+        seed = getattr(arg, "seed", None)
+        if seed is not None:
+            set_seed(seed)
+            break
+
     return tuple(parsed_args)
diff --git a/src/llamafactory/v1/config/training_args.py b/src/llamafactory/v1/config/training_args.py
index 8fe0c1cf1..5d13ef2fb 100644
--- a/src/llamafactory/v1/config/training_args.py
+++ b/src/llamafactory/v1/config/training_args.py
@@ -66,7 +66,7 @@ class TrainingArguments:
         metadata={"help": "Number of workers for batching."},
     )
     enable_activation_checkpointing: bool = field(
-        default=True,
+        default=False,
         metadata={"help": "Enable activation checkpointing for training."},
     )
     dist_config: PluginConfig | None = field(
@@ -81,6 +81,10 @@ class TrainingArguments:
         default=None,
         metadata={"help": "Learning rate scheduler configuration for training."},
     )
+    seed: int = field(
+        default=42,
+        metadata={"help": "Random seed that will be set at the beginning of training."},
+    )
 
     def __post_init__(self) -> None:
         self.dist_config = get_plugin_config(self.dist_config)
diff --git a/src/llamafactory/v1/core/base_trainer.py b/src/llamafactory/v1/core/base_trainer.py
index 69660361e..3e277e435 100644
--- a/src/llamafactory/v1/core/base_trainer.py
+++ b/src/llamafactory/v1/core/base_trainer.py
@@ -76,7 +76,7 @@ class BaseTrainer:
         if self.args.enable_activation_checkpointing:
             self.model.gradient_checkpointing_enable({"use_reentrant": False})
 
-        self._accelerate_engine = None
+        self._deepspeed_engine = None
         dist_name = self.args.dist_config.name if self.args.dist_config is not None else None
 
         if dist_name == "deepspeed":
@@ -108,6 +108,7 @@ class BaseTrainer:
             cutoff_len=self.args.cutoff_len,
             batching_workers=self.args.batching_workers,
             batching_strategy=self.args.batching_strategy,
+            seed=self.args.seed,
         )
 
     def _shard_model(self) -> None:
diff --git a/src/llamafactory/v1/core/utils/batching.py b/src/llamafactory/v1/core/utils/batching.py
index e87a95974..2243fb078 100644
--- a/src/llamafactory/v1/core/utils/batching.py
+++ b/src/llamafactory/v1/core/utils/batching.py
@@ -26,6 +26,7 @@ from collections.abc import Iterator
 from typing import Any
 
+import torch
 from torch.utils.data import default_collate
 from torchdata.stateful_dataloader import StatefulDataLoader
 from torchdata.stateful_dataloader.sampler import StatefulDistributedSampler
@@ -71,6 +72,7 @@ class BatchGenerator(Iterator):
         batching_strategy: BatchingStrategy = BatchingStrategy.NORMAL,
         pin_memory: bool = True,
         drop_last: bool = True,
+        seed: int = 42,
     ) -> None:
         self.dataset = dataset
         self.renderer = renderer
@@ -82,6 +84,7 @@ class BatchGenerator(Iterator):
         self.batching_strategy = batching_strategy
         self.pin_memory = pin_memory
         self.drop_last = drop_last
+        self.seed = seed
         # TODO: support length and infinity
 
         dp_size = DistributedInterface().get_world_size(Dim.DP)
@@ -128,12 +131,15 @@
                 num_replicas=DistributedInterface().get_world_size(Dim.DP),
                 rank=DistributedInterface().get_rank(Dim.DP),
                 shuffle=True,
-                seed=0,
+                seed=self.seed,
                 drop_last=self.drop_last,
             )
         else:
             raise NotImplementedError("Iterable dataset is not supported yet.")
 
+        generator = torch.Generator()
+        generator.manual_seed(self.seed)
+
         self._data_provider = StatefulDataLoader(
             self.dataset,
             batch_size=self.micro_batch_size * self.num_micro_batch,
@@ -143,6 +149,7 @@
             pin_memory=self.pin_memory,
             pin_memory_device=DistributedInterface().current_device.type,
             drop_last=self.drop_last,
+            generator=generator,
         )
         if self.batching_strategy == BatchingStrategy.NORMAL:
             self._length = len(self._data_provider)
diff --git a/src/llamafactory/v1/plugins/trainer_plugins/distributed/fsdp2.py b/src/llamafactory/v1/plugins/trainer_plugins/distributed/fsdp2.py
index dbe2626bf..f32607627 100644
--- a/src/llamafactory/v1/plugins/trainer_plugins/distributed/fsdp2.py
+++ b/src/llamafactory/v1/plugins/trainer_plugins/distributed/fsdp2.py
@@ -166,12 +166,11 @@ class FSDP2Engine:
             offload_policy=CPUOffloadPolicy(pin_memory=self.pin_memory) if self.offload_params else None,
         )
 
-        use_gradient_checkpointing = True  # Could be configurable
-        if use_gradient_checkpointing:
+        # BaseTrainer is the single source of truth for gradient checkpointing.
+        # FSDP2 only applies the input-grad compatibility hook when checkpointing is already enabled.
+        if getattr(model, "is_gradient_checkpointing", False):
             if self.rank == 0:
-                logger.info("Enabling gradient checkpointing (transformers native)...")
-
-            model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
+                logger.info("Gradient checkpointing is enabled. Applying FSDP2 input grad preparation.")
 
         if hasattr(model, "enable_input_require_grads"):
             model.enable_input_require_grads()
diff --git a/src/llamafactory/v1/utils/helper.py b/src/llamafactory/v1/utils/helper.py
index 3f7b75505..dd453093c 100644
--- a/src/llamafactory/v1/utils/helper.py
+++ b/src/llamafactory/v1/utils/helper.py
@@ -15,12 +15,22 @@
 import torch
 from transformers import PreTrainedTokenizer
+from transformers import set_seed as hf_set_seed
 
 from ..accelerator.interface import DistributedInterface
 from .constants import IGNORE_INDEX
 from .types import BatchInput, ModelInput, Processor, Tensor
 
 
+def set_seed(seed: int) -> None:
+    """Set seed for reproducibility.
+
+    Args:
+        seed: Random seed.
+    """
+    hf_set_seed(seed)
+
+
 def is_tokenizer(processor: Processor) -> bool:
     """Check if processor is tokenizer.