mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2026-04-28 02:39:03 +08:00)
[v1] add seed for training and fix gradient checkpointing (#10211)
```diff
@@ -76,7 +76,7 @@ class BaseTrainer:
         if self.args.enable_activation_checkpointing:
             self.model.gradient_checkpointing_enable({"use_reentrant": False})

         self._accelerate_engine = None
         self._deepspeed_engine = None
         dist_name = self.args.dist_config.name if self.args.dist_config is not None else None

         if dist_name == "deepspeed":
```
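The `{"use_reentrant": False}` argument follows the Hugging Face convention, where `gradient_checkpointing_enable(gradient_checkpointing_kwargs=...)` forwards these kwargs to `torch.utils.checkpoint.checkpoint`. The non-reentrant implementation is the recommended one: it recomputes activations in the backward pass without replaying autograd reentrantly, which avoids failures when checkpointed inputs do not require grad (e.g. frozen embeddings). A minimal sketch of the underlying mechanism using PyTorch's public API (illustrative only, not LLaMA-Factory code):

```python
import torch
from torch.utils.checkpoint import checkpoint

class Block(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(16, 16)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Recompute activations during backward instead of storing them,
        # trading compute for memory. use_reentrant=False selects the
        # non-reentrant variant that this commit enables.
        return checkpoint(self.linear, x, use_reentrant=False)

x = torch.randn(4, 16, requires_grad=True)
Block()(x).sum().backward()
```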
```diff
@@ -108,6 +108,7 @@ class BaseTrainer:
             cutoff_len=self.args.cutoff_len,
             batching_workers=self.args.batching_workers,
             batching_strategy=self.args.batching_strategy,
+            seed=self.args.seed,
         )

     def _shard_model(self) -> None:
```
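Passing `seed=self.args.seed` makes batch shuffling a function of the user-visible seed. For a fully reproducible run, this is typically paired with seeding the global RNG sources once at startup; a generic sketch of that companion pattern (the `seed_everything` helper is hypothetical, not part of this commit):

```python
import random

import numpy as np
import torch

def seed_everything(seed: int) -> None:
    """Seed the common RNG sources used in a training run (generic pattern)."""
    random.seed(seed)                 # Python's stdlib RNG
    np.random.seed(seed)              # NumPy's legacy global RNG
    torch.manual_seed(seed)           # torch's global CPU RNG (also seeds CUDA RNGs)
    torch.cuda.manual_seed_all(seed)  # explicit per-GPU seeding; no-op without CUDA

seed_everything(42)
```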
```diff
@@ -26,6 +26,7 @@
 from collections.abc import Iterator
 from typing import Any

+import torch
 from torch.utils.data import default_collate
 from torchdata.stateful_dataloader import StatefulDataLoader
 from torchdata.stateful_dataloader.sampler import StatefulDistributedSampler
@@ -71,6 +72,7 @@ class BatchGenerator(Iterator):
         batching_strategy: BatchingStrategy = BatchingStrategy.NORMAL,
         pin_memory: bool = True,
         drop_last: bool = True,
+        seed: int = 42,
     ) -> None:
         self.dataset = dataset
         self.renderer = renderer
@@ -82,6 +84,7 @@ class BatchGenerator(Iterator):
         self.batching_strategy = batching_strategy
         self.pin_memory = pin_memory
         self.drop_last = drop_last
+        self.seed = seed
         # TODO: support length and infinity
         dp_size = DistributedInterface().get_world_size(Dim.DP)
```
```diff
@@ -128,12 +131,15 @@ class BatchGenerator(Iterator):
                 num_replicas=DistributedInterface().get_world_size(Dim.DP),
                 rank=DistributedInterface().get_rank(Dim.DP),
                 shuffle=True,
-                seed=0,
+                seed=self.seed,
                 drop_last=self.drop_last,
             )
         else:
             raise NotImplementedError("Iterable dataset is not supported yet.")

+        generato_seed = torch.Generator()
+        generato_seed.manual_seed(self.seed)
+
         self._data_provider = StatefulDataLoader(
             self.dataset,
             batch_size=self.micro_batch_size * self.num_micro_batch,
```
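The `seed=0` to `seed=self.seed` change is what actually propagates the configured seed into shuffling. With `DistributedSampler`-style samplers, every data-parallel rank must construct the sampler with the same seed so the per-epoch permutations agree; otherwise the shards drawn by different ranks would overlap or miss samples. A minimal sketch with the plain PyTorch sampler (a stand-in for `StatefulDistributedSampler`; world size and rank are hard-coded for illustration):

```python
import torch
from torch.utils.data import DistributedSampler, TensorDataset

dataset = TensorDataset(torch.arange(16))

sampler = DistributedSampler(
    dataset,
    num_replicas=2,  # stand-in for DistributedInterface().get_world_size(Dim.DP)
    rank=0,          # stand-in for DistributedInterface().get_rank(Dim.DP)
    shuffle=True,
    seed=42,         # must be identical on every rank
    drop_last=True,
)

# The permutation is a function of (seed, epoch): calling set_epoch each
# epoch reshuffles the data while keeping runs reproducible.
sampler.set_epoch(0)
print(list(sampler))  # same index order on every run with the same seed
```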
```diff
@@ -143,6 +149,7 @@ class BatchGenerator(Iterator):
             pin_memory=self.pin_memory,
             pin_memory_device=DistributedInterface().current_device.type,
             drop_last=self.drop_last,
+            generator=generato_seed,
         )
         if self.batching_strategy == BatchingStrategy.NORMAL:
             self._length = len(self._data_provider)
```
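Handing a dedicated, manually seeded `torch.Generator` to the dataloader (rather than relying on `torch.manual_seed` globally) keeps batch-order and worker randomness isolated from model-init and dropout randomness. A minimal sketch of the same pattern with a plain `DataLoader` (a stand-in for `StatefulDataLoader`):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.arange(8))

def make_loader(seed: int) -> DataLoader:
    # A dedicated generator: seeding it does not disturb the global RNG
    # used for weight init or dropout, and vice versa.
    g = torch.Generator()
    g.manual_seed(seed)
    return DataLoader(dataset, batch_size=2, shuffle=True, generator=g)

# Identical seeds yield identical batch orders across fresh loaders.
first = [batch[0].tolist() for batch in make_loader(42)]
second = [batch[0].tolist() for batch in make_loader(42)]
assert first == second
```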