mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2026-04-28 02:39:03 +08:00
[v1] support training with fsdp2 (#9773)
Co-authored-by: frozenleaves <frozen@Mac.local> Co-authored-by: Yaowei Zheng <hiyouga@buaa.edu.cn>
This commit is contained in:
@@ -67,7 +67,11 @@ class BaseTrainer:
|
||||
self.model_input_names = self.renderer.processor.model_input_names
|
||||
|
||||
self._create_batch_generator()
|
||||
self.num_training_steps = self.args.num_train_epochs * len(self.train_batch_generator)
|
||||
# Calculate num_training_steps: max_steps takes priority if set
|
||||
if self.args.max_steps is not None and self.args.max_steps > 0:
|
||||
self.num_training_steps = self.args.max_steps
|
||||
else:
|
||||
self.num_training_steps = self.args.num_train_epochs * len(self.train_batch_generator)
|
||||
|
||||
if self.args.enable_activation_checkpointing:
|
||||
self.model.gradient_checkpointing_enable({"use_reentrant": False})
|
||||
@@ -98,7 +102,22 @@ class BaseTrainer:
|
||||
)
|
||||
|
||||
def _shard_model(self) -> None:
|
||||
pass
|
||||
if self.args.dist_config is None:
|
||||
if DistributedInterface().get_world_size(Dim.DP) > 1:
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
|
||||
logger.warning_rank0(
|
||||
"dist_config is None but distributed training is enabled; falling back to DistributedDataParallel."
|
||||
)
|
||||
device_ids = None if self.device.type == "cpu" else [self.device.index]
|
||||
self.model = DDP(self.model, device_ids=device_ids)
|
||||
else:
|
||||
from ..plugins.trainer_plugins.distributed.hub import DistributedPlugin
|
||||
|
||||
self.model = DistributedPlugin(self.args.dist_config.name)(
|
||||
self.model,
|
||||
self.args.dist_config,
|
||||
)
|
||||
|
||||
def _init_optimizer(self) -> None:
|
||||
"""Init optimizer."""
|
||||
@@ -162,7 +181,9 @@ class BaseTrainer:
|
||||
step_loss += loss.item()
|
||||
|
||||
grad_norm = torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.max_grad_norm).item()
|
||||
if not torch.isfinite(grad_norm):
|
||||
|
||||
# isfinite(): argument 'input' (position 1) must be Tensor, not float
|
||||
if not torch.isfinite(torch.tensor(grad_norm)): # type: ignore # pyright: ignore [reportUnknownReturnType]
|
||||
logger.warning_rank0(f"Gradient norm is not finite: {grad_norm}")
|
||||
else:
|
||||
self.optimizer.step()
|
||||
@@ -172,10 +193,17 @@ class BaseTrainer:
|
||||
|
||||
step_loss, grad_norm = DistributedInterface().all_reduce([step_loss, grad_norm])
|
||||
DistributedInterface().sync()
|
||||
print(f"Epoch {epoch}, Step {self.global_step}, Loss: {step_loss:.4f}, Grad Norm: {grad_norm:.4f}")
|
||||
if DistributedInterface().get_rank() == 0:
|
||||
print(f"Epoch {epoch}, Step {self.global_step}, Loss: {step_loss:.4f}, Grad Norm: {grad_norm:.4f}")
|
||||
|
||||
# Check if max_steps is reached
|
||||
if self.global_step >= self.num_training_steps:
|
||||
logger.info_rank0(f"Reached max_steps ({self.num_training_steps}), stopping training.")
|
||||
return
|
||||
|
||||
def save_model(self) -> None:
|
||||
"""Save the model."""
|
||||
self.model.save_pretrained(self.args.output_dir)
|
||||
model_to_save = self.model.module if hasattr(self.model, "module") else self.model
|
||||
model_to_save.save_pretrained(self.args.output_dir)
|
||||
self.renderer.processor.save_pretrained(self.args.output_dir)
|
||||
logger.info_rank0(f"Model saved to {self.args.output_dir}")
|
||||
|
||||
@@ -30,7 +30,7 @@ from torch.utils.data import default_collate
|
||||
from torchdata.stateful_dataloader import StatefulDataLoader
|
||||
from torchdata.stateful_dataloader.sampler import StatefulDistributedSampler
|
||||
|
||||
from ...accelerator.interface import DistributedInterface
|
||||
from ...accelerator.interface import Dim, DistributedInterface
|
||||
from ...config import BatchingStrategy
|
||||
from ...utils import logging
|
||||
from ...utils.helper import pad_and_truncate
|
||||
@@ -83,8 +83,7 @@ class BatchGenerator(Iterator):
|
||||
self.pin_memory = pin_memory
|
||||
self.drop_last = drop_last
|
||||
# TODO: support length and infinity
|
||||
|
||||
dp_size = DistributedInterface().get_world_size("dp")
|
||||
dp_size = DistributedInterface().get_world_size(Dim.DP)
|
||||
|
||||
if self.global_batch_size is None:
|
||||
self.global_batch_size = dp_size * micro_batch_size
|
||||
@@ -126,8 +125,8 @@ class BatchGenerator(Iterator):
|
||||
if len(self.dataset) != -1:
|
||||
sampler = StatefulDistributedSampler(
|
||||
self.dataset,
|
||||
num_replicas=DistributedInterface().get_world_size("dp"),
|
||||
rank=DistributedInterface().get_rank("dp"),
|
||||
num_replicas=DistributedInterface().get_world_size(Dim.DP),
|
||||
rank=DistributedInterface().get_rank(Dim.DP),
|
||||
shuffle=True,
|
||||
seed=0,
|
||||
drop_last=self.drop_last,
|
||||
@@ -142,6 +141,7 @@ class BatchGenerator(Iterator):
|
||||
num_workers=self.batching_workers,
|
||||
collate_fn=self.renderer.process_samples,
|
||||
pin_memory=self.pin_memory,
|
||||
pin_memory_device=DistributedInterface().current_device.type,
|
||||
drop_last=self.drop_last,
|
||||
)
|
||||
if self.batching_strategy == BatchingStrategy.NORMAL:
|
||||
|
||||
Reference in New Issue
Block a user