[v1] support training with fsdp2 (#9773)

Author: 浮梦
Date: 2026-01-25 19:41:58 +08:00
Committed by: GitHub
Co-authored-by: frozenleaves <frozen@Mac.local>
Co-authored-by: Yaowei Zheng <hiyouga@buaa.edu.cn>

commit f9f11dcb97 (parent 641bfdd482)
15 changed files with 801 additions and 33 deletions


@@ -67,7 +67,11 @@ class BaseTrainer:
         self.model_input_names = self.renderer.processor.model_input_names
 
         self._create_batch_generator()
-        self.num_training_steps = self.args.num_train_epochs * len(self.train_batch_generator)
+        # Calculate num_training_steps: max_steps takes priority if set
+        if self.args.max_steps is not None and self.args.max_steps > 0:
+            self.num_training_steps = self.args.max_steps
+        else:
+            self.num_training_steps = self.args.num_train_epochs * len(self.train_batch_generator)
 
         if self.args.enable_activation_checkpointing:
             self.model.gradient_checkpointing_enable({"use_reentrant": False})
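For reference, a minimal sketch of the precedence this hunk establishes; the helper name and signature below are illustrative, not the trainer's API:

def resolve_training_steps(max_steps, num_train_epochs, steps_per_epoch):
    # A positive max_steps caps the training horizon; otherwise derive it from epochs.
    if max_steps is not None and max_steps > 0:
        return max_steps
    return num_train_epochs * steps_per_epoch

assert resolve_training_steps(50, 3, 100) == 50     # capped by max_steps
assert resolve_training_steps(None, 3, 100) == 300  # epoch-derived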
@@ -98,7 +102,22 @@ class BaseTrainer:
         )
 
     def _shard_model(self) -> None:
-        pass
+        if self.args.dist_config is None:
+            if DistributedInterface().get_world_size(Dim.DP) > 1:
+                from torch.nn.parallel import DistributedDataParallel as DDP
+
+                logger.warning_rank0(
+                    "dist_config is None but distributed training is enabled; falling back to DistributedDataParallel."
+                )
+                device_ids = None if self.device.type == "cpu" else [self.device.index]
+                self.model = DDP(self.model, device_ids=device_ids)
+        else:
+            from ..plugins.trainer_plugins.distributed.hub import DistributedPlugin
+
+            self.model = DistributedPlugin(self.args.dist_config.name)(
+                self.model,
+                self.args.dist_config,
+            )
 
     def _init_optimizer(self) -> None:
        """Init optimizer."""
@@ -162,7 +181,9 @@ class BaseTrainer:
                     step_loss += loss.item()
 
                 grad_norm = torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.max_grad_norm).item()
-                if not torch.isfinite(grad_norm):
+                # torch.isfinite() expects a Tensor, not the Python float returned by .item()
+                if not torch.isfinite(torch.tensor(grad_norm)):  # type: ignore # pyright: ignore [reportUnknownReturnType]
                     logger.warning_rank0(f"Gradient norm is not finite: {grad_norm}")
                 else:
                     self.optimizer.step()
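Design note: because .item() already yields a Python float, the tensor round-trip can be avoided entirely with the standard library. An equivalent check (an illustrative alternative, not what the commit ships):

import math

grad_norm = 2.5  # stand-in for clip_grad_norm_(...).item()
if not math.isfinite(grad_norm):  # works directly on the float
    print(f"Gradient norm is not finite: {grad_norm}")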
@@ -172,10 +193,17 @@ class BaseTrainer:
                 step_loss, grad_norm = DistributedInterface().all_reduce([step_loss, grad_norm])
                 DistributedInterface().sync()
-                print(f"Epoch {epoch}, Step {self.global_step}, Loss: {step_loss:.4f}, Grad Norm: {grad_norm:.4f}")
+                if DistributedInterface().get_rank() == 0:
+                    print(f"Epoch {epoch}, Step {self.global_step}, Loss: {step_loss:.4f}, Grad Norm: {grad_norm:.4f}")
+
+                # Check if max_steps is reached
+                if self.global_step >= self.num_training_steps:
+                    logger.info_rank0(f"Reached max_steps ({self.num_training_steps}), stopping training.")
+                    return
 
     def save_model(self) -> None:
         """Save the model."""
-        self.model.save_pretrained(self.args.output_dir)
+        model_to_save = self.model.module if hasattr(self.model, "module") else self.model
+        model_to_save.save_pretrained(self.args.output_dir)
         self.renderer.processor.save_pretrained(self.args.output_dir)
         logger.info_rank0(f"Model saved to {self.args.output_dir}")
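The hasattr(self.model, "module") check targets the DDP fallback in _shard_model: DDP wraps the model and exposes the original as .module, whereas FSDP2's fully_shard shards in place and adds no wrapper. A type-explicit variant of the same unwrap (a sketch, assuming only those two paths exist):

import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

def unwrap_for_saving(model: nn.Module) -> nn.Module:
    # DDP: save the inner module; FSDP2-sharded or plain models pass through.
    return model.module if isinstance(model, DDP) else model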


@@ -30,7 +30,7 @@ from torch.utils.data import default_collate
 from torchdata.stateful_dataloader import StatefulDataLoader
 from torchdata.stateful_dataloader.sampler import StatefulDistributedSampler
 
-from ...accelerator.interface import DistributedInterface
+from ...accelerator.interface import Dim, DistributedInterface
 from ...config import BatchingStrategy
 from ...utils import logging
 from ...utils.helper import pad_and_truncate
@@ -83,8 +83,7 @@ class BatchGenerator(Iterator):
         self.pin_memory = pin_memory
         self.drop_last = drop_last
 
-        # TODO: support length and infinity
-        dp_size = DistributedInterface().get_world_size("dp")
+        dp_size = DistributedInterface().get_world_size(Dim.DP)
 
         if self.global_batch_size is None:
             self.global_batch_size = dp_size * micro_batch_size
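Replacing the bare "dp" string with Dim.DP turns mesh-dimension typos from silent runtime lookups into attribute errors at import time. The real enum lives in ...accelerator.interface and its exact shape is an assumption here, but a str-valued enum like the sketch below would keep existing name-based lookups working:

from enum import Enum

class Dim(str, Enum):
    # str subclass: Dim.DP == "dp", so string-keyed mesh lookups still match.
    DP = "dp"
    TP = "tp"  # hypothetical second dimension

assert Dim.DP == "dp"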
@@ -126,8 +125,8 @@ class BatchGenerator(Iterator):
         if len(self.dataset) != -1:
             sampler = StatefulDistributedSampler(
                 self.dataset,
-                num_replicas=DistributedInterface().get_world_size("dp"),
-                rank=DistributedInterface().get_rank("dp"),
+                num_replicas=DistributedInterface().get_world_size(Dim.DP),
+                rank=DistributedInterface().get_rank(Dim.DP),
                 shuffle=True,
                 seed=0,
                 drop_last=self.drop_last,
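For intuition on what the sampler does with those arguments, a self-contained example showing that two DP ranks receive disjoint shards that together cover the dataset (toy sizes; no process group is needed when num_replicas and rank are passed explicitly):

import torch
from torch.utils.data import TensorDataset
from torchdata.stateful_dataloader.sampler import StatefulDistributedSampler

ds = TensorDataset(torch.arange(8))
s0 = StatefulDistributedSampler(ds, num_replicas=2, rank=0, shuffle=False)
s1 = StatefulDistributedSampler(ds, num_replicas=2, rank=1, shuffle=False)
assert sorted([*s0, *s1]) == list(range(8))  # disjoint shards, full coverage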
@@ -142,6 +141,7 @@ class BatchGenerator(Iterator):
             num_workers=self.batching_workers,
             collate_fn=self.renderer.process_samples,
             pin_memory=self.pin_memory,
+            pin_memory_device=DistributedInterface().current_device.type,
             drop_last=self.drop_last,
         )
         if self.batching_strategy == BatchingStrategy.NORMAL:
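Passing pin_memory_device makes pinned host memory target the active accelerator backend rather than the CUDA default, which matters on non-CUDA builds (e.g. XPU or NPU). A minimal standalone equivalent with the stock DataLoader; the CUDA/CPU probe below is an illustrative stand-in for DistributedInterface().current_device:

import torch
from torch.utils.data import DataLoader, TensorDataset

ds = TensorDataset(torch.arange(16, dtype=torch.float32))
device_type = "cuda" if torch.cuda.is_available() else "cpu"
loader = DataLoader(
    ds,
    batch_size=4,
    pin_memory=(device_type != "cpu"),  # pinning is a no-op on CPU-only runs
    pin_memory_device=device_type if device_type != "cpu" else "",
)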