def _warmup_grad_norm(self, model: HFModel) -> None:
    """Run a dummy grad-norm computation to initialize NCCL communication groups.

    Skipped when no FSDP device mesh is configured (there is nothing to
    communicate over). The dummy gradients are removed before returning, so
    real training starts from a clean slate.
    """
    if self.fsdp_mesh is None:
        return

    logger.info_rank0("Warming up grad norm computation...")

    trainable = [p for p in model.parameters() if p.requires_grad]

    # Install an all-zero gradient on every trainable parameter so that
    # clip_grad_norm_ has something to reduce across ranks.
    for p in trainable:
        p.grad = torch.zeros_like(p)

    with torch.no_grad():
        dummy_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        # When the norm comes back as a DTensor, full_tensor() issues the
        # cross-rank collective — exactly the communication path we want
        # to warm up here.
        if isinstance(dummy_norm, torch.distributed._tensor.DTensor):
            dummy_norm = dummy_norm.full_tensor()

    # Discard the dummy gradients again.
    for p in trainable:
        p.grad = None

    logger.info_rank0("Grad norm warmup completed.")