[v1] fix device mesh and clip_grad_norm for ulysses cp (#10366)

This commit is contained in:
sunyi0505
2026-04-21 10:54:54 +08:00
committed by GitHub
parent c4bbac49b2
commit f5d739b132

View File

@@ -345,8 +345,32 @@ class FSDP2Engine:
else: else:
model = self.prepare_model(model) model = self.prepare_model(model)
self._warmup_grad_norm(model)
return model return model
def _warmup_grad_norm(self, model: HFModel) -> None:
"""Warmup grad norm computation to initialize NCCL communication groups."""
if self.fsdp_mesh is None:
return
logger.info_rank0("Warming up grad norm computation...")
for param in model.parameters():
if param.requires_grad:
param.grad = torch.zeros_like(param)
with torch.no_grad():
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
if isinstance(grad_norm, torch.distributed._tensor.DTensor):
grad_norm = grad_norm.full_tensor()
for param in model.parameters():
if param.requires_grad:
param.grad = None
logger.info_rank0("Grad norm warmup completed.")
def _load_from_dcp(self, model: HFModel, dcp_path: str): def _load_from_dcp(self, model: HFModel, dcp_path: str):
import torch.distributed.checkpoint as dcp import torch.distributed.checkpoint as dcp