diff --git a/src/llamafactory/v1/plugins/trainer_plugins/distributed/fsdp2.py b/src/llamafactory/v1/plugins/trainer_plugins/distributed/fsdp2.py
index eb905263a..ddc9fb7f8 100644
--- a/src/llamafactory/v1/plugins/trainer_plugins/distributed/fsdp2.py
+++ b/src/llamafactory/v1/plugins/trainer_plugins/distributed/fsdp2.py
@@ -345,8 +345,43 @@ class FSDP2Engine:
         else:
             model = self.prepare_model(model)
 
+        self._warmup_grad_norm(model)
+
         return model
 
+    def _warmup_grad_norm(self, model: HFModel) -> None:
+        """Run a throwaway grad-norm computation to initialize NCCL communication groups.
+
+        FSDP2 creates the process groups used by the distributed grad-norm
+        reduction lazily; doing one dummy pass here pays that setup cost up
+        front instead of on the first real optimizer step.
+        """
+        if self.fsdp_mesh is None:
+            return
+
+        logger.info_rank0("Warming up grad norm computation...")
+
+        # Give every trainable parameter a zero gradient so clip_grad_norm_
+        # exercises the full cross-rank reduction without touching real state.
+        for param in model.parameters():
+            if param.requires_grad:
+                param.grad = torch.zeros_like(param)
+
+        with torch.no_grad():
+            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
+            # Sharded (DTensor) results expose full_tensor(); materializing it
+            # triggers the collective we want to warm up. Duck-typing avoids
+            # importing the private torch.distributed._tensor module path.
+            if hasattr(grad_norm, "full_tensor"):
+                grad_norm = grad_norm.full_tensor()
+
+        # Drop the dummy gradients so the first real backward starts clean.
+        for param in model.parameters():
+            if param.requires_grad:
+                param.grad = None
+
+        logger.info_rank0("Grad norm warmup completed.")
+
     def _load_from_dcp(self, model: HFModel, dcp_path: str):
         import torch.distributed.checkpoint as dcp
 