mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2026-04-21 20:36:02 +08:00
[v1] fix device mesh and clip_grad_norm for ulysses cp (#10366)
This commit is contained in:
@@ -345,8 +345,32 @@ class FSDP2Engine:
|
|||||||
else:
|
else:
|
||||||
model = self.prepare_model(model)
|
model = self.prepare_model(model)
|
||||||
|
|
||||||
|
self._warmup_grad_norm(model)
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
def _warmup_grad_norm(self, model: HFModel) -> None:
    """Run one dummy grad-norm pass so the NCCL communication groups are created eagerly.

    Fills every trainable parameter's ``.grad`` with zeros, invokes
    ``clip_grad_norm_`` once (materializing the full tensor when the norm
    comes back as a DTensor, which triggers the collective), then clears
    the gradients again. No-op when no FSDP mesh is configured.
    """
    if self.fsdp_mesh is None:
        return

    logger.info_rank0("Warming up grad norm computation...")

    # Collect trainable parameters once; we touch them before and after the pass.
    trainable = [p for p in model.parameters() if p.requires_grad]
    for p in trainable:
        p.grad = torch.zeros_like(p)

    with torch.no_grad():
        norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        if isinstance(norm, torch.distributed._tensor.DTensor):
            # full_tensor() performs the all-reduce that initializes the comm groups.
            norm = norm.full_tensor()

    # Restore the clean pre-training state: no stale zero gradients.
    for p in trainable:
        p.grad = None

    logger.info_rank0("Grad norm warmup completed.")
||||||
def _load_from_dcp(self, model: HFModel, dcp_path: str):
|
def _load_from_dcp(self, model: HFModel, dcp_path: str):
|
||||||
import torch.distributed.checkpoint as dcp
|
import torch.distributed.checkpoint as dcp
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user