Mirror of https://github.com/hiyouga/LLaMA-Factory.git, synced 2026-04-21 20:36:02 +08:00.
[v1] fix device mesh and clip_grad_norm for ulysses cp (#10366)
This commit is contained in:
@@ -345,8 +345,32 @@ class FSDP2Engine:
|
||||
else:
|
||||
model = self.prepare_model(model)
|
||||
|
||||
self._warmup_grad_norm(model)
|
||||
|
||||
return model
|
||||
|
||||
def _warmup_grad_norm(self, model: HFModel) -> None:
    """Run one throwaway grad-norm pass so NCCL communication groups get created early.

    No-op unless an FSDP device mesh is configured. Temporarily fills every
    trainable parameter's ``.grad`` with zeros, invokes ``clip_grad_norm_``
    (which triggers the cross-rank reduction that sets up the groups), then
    clears the gradients again so real training starts from a clean state.
    """
    if self.fsdp_mesh is None:
        return

    logger.info_rank0("Warming up grad norm computation...")

    # Collect the trainable parameters once; give each a zero gradient so
    # the norm computation has something to reduce across ranks.
    trainable = [p for p in model.parameters() if p.requires_grad]
    for p in trainable:
        p.grad = torch.zeros_like(p)

    with torch.no_grad():
        grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        # Under FSDP2 the norm comes back as a DTensor; materialize it so the
        # collective actually runs end-to-end.
        if isinstance(grad_norm, torch.distributed._tensor.DTensor):
            grad_norm = grad_norm.full_tensor()

    # Drop the synthetic gradients before real training begins.
    for p in trainable:
        p.grad = None

    logger.info_rank0("Grad norm warmup completed.")
||||
def _load_from_dcp(self, model: HFModel, dcp_path: str):
|
||||
import torch.distributed.checkpoint as dcp
|
||||
|
||||
|
||||
Reference in New Issue
Block a user