[v1] add sft (#9752)

Author: Yaowei Zheng
Date: 2026-01-12 03:15:01 +08:00
Committed by: GitHub
Parent: 4d3621e3d3
Commit: 958b9c3468

29 changed files with 439 additions and 305 deletions


@@ -174,7 +174,7 @@ class DistributedInterface:
"""Get device mesh for specified dimension."""
if dim is None:
raise ValueError("dim must be specified.")
elif self.model_device_mesh is None:
elif not self._is_distributed:
return None
elif dim in self.strategy.data_mesh_dim_names:
return self.data_device_mesh[dim.value]
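
Note (not part of the diff): this hunk replaces the old `model_device_mesh is None` test with a dedicated `_is_distributed` flag. A minimal sketch of how such a flag is typically derived with torch.distributed; the attribute name matches the diff, but the helper below and its initialization logic are assumptions, not code from this commit:

import torch.distributed as dist

def detect_distributed() -> bool:
    # Hypothetical helper: True only when a default process group has been initialized.
    return dist.is_available() and dist.is_initialized()
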
@@ -183,14 +183,14 @@ class DistributedInterface:
     def get_group(self, dim: Dim | None = None) -> Optional[ProcessGroup]:
         """Get process group for specified dimension."""
-        if self.model_device_mesh is None or dim is None:
+        if not self._is_distributed or dim is None:
             return None
         else:
             return self.get_device_mesh(dim).get_group()
 
     def get_rank(self, dim: Dim | None = None) -> int:
         """Get parallel rank for specified dimension."""
-        if self.model_device_mesh is None:
+        if not self._is_distributed:
             return 0
         elif dim is None:
             return self._rank
@@ -199,7 +199,7 @@ class DistributedInterface:
     def get_world_size(self, dim: Dim | None = None) -> int:
         """Get parallel size for specified dimension."""
-        if self.model_device_mesh is None:
+        if not self._is_distributed:
             return 1
         elif dim is None:
             return self._world_size
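
Note (not part of the diff): the non-distributed defaults, rank 0 and world size 1, make a single-process run behave like a one-rank job. A standalone sketch of that gating pattern; the class and its initialization are illustrative, not the project's code:

import torch.distributed as dist

class MiniDistributedInterface:
    """Illustrative sketch of the _is_distributed gating, not the real class."""

    def __init__(self) -> None:
        self._is_distributed = dist.is_available() and dist.is_initialized()
        self._rank = dist.get_rank() if self._is_distributed else 0
        self._world_size = dist.get_world_size() if self._is_distributed else 1

    def get_rank(self) -> int:
        # Single-process runs report rank 0, matching the diff's fallback.
        return self._rank

    def get_world_size(self) -> int:
        # Single-process runs report world size 1.
        return self._world_size
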
@@ -216,7 +216,7 @@ class DistributedInterface:
     def all_gather(self, data: TensorLike, dim: Dim | None = Dim.DP) -> TensorLike:
         """Gather tensor across specified parallel group."""
-        if self.model_device_mesh is not None:
+        if self._is_distributed:
             return helper.operate_tensorlike(helper.all_gather, data, group=self.get_group(dim))
         else:
             return data
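
Note (not part of the diff): the collectives follow the same pattern, degrading to identity operations in a single-process run so callers never branch on the launch mode. A hedged sketch using plain torch.distributed; `helper.operate_tensorlike` is the project's own wrapper and is not reproduced here, and the function name below is illustrative:

import torch
import torch.distributed as dist

def all_gather_or_passthrough(data: torch.Tensor, is_distributed: bool) -> torch.Tensor:
    # Without a process group, the local tensor already is the "gathered" result.
    if not is_distributed:
        return data
    buffers = [torch.empty_like(data) for _ in range(dist.get_world_size())]
    dist.all_gather(buffers, data)
    return torch.cat(buffers, dim=0)
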
@@ -225,29 +225,32 @@ class DistributedInterface:
         self, data: TensorLike, op: helper.ReduceOp = helper.ReduceOp.MEAN, dim: Dim | None = Dim.DP
     ) -> TensorLike:
         """Reduce tensor across specified parallel group."""
-        if self.model_device_mesh is not None:
+        if self._is_distributed:
             return helper.operate_tensorlike(helper.all_reduce, data, op=op, group=self.get_group(dim))
         else:
             return data
 
     def broadcast(self, data: TensorLike, src: int = 0, dim: Dim | None = Dim.DP) -> TensorLike:
         """Broadcast tensor across specified parallel group."""
-        if self.model_device_mesh is not None:
+        if self._is_distributed:
             return helper.operate_tensorlike(helper.broadcast, data, src=src, group=self.get_group(dim))
         else:
             return data
 
     def sync(self) -> None:
         """Synchronize all processes."""
-        helper.synchronize()
+        if self._is_distributed:
+            helper.synchronize()
 
     def barrier(self) -> None:
         """Barrier all processes."""
-        barrier()
+        if self._is_distributed:
+            barrier()
 
     def destroy(self) -> None:
         """Destroy all processes."""
-        destroy_process_group()
+        if self._is_distributed:
+            destroy_process_group()
 
 
 if __name__ == "__main__":
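
Note (not part of the diff): sync, barrier, and destroy are now no-ops when no process group was created, so the same entry point can run without torchrun without crashing at teardown. A minimal sketch of that guarded teardown; the function name is illustrative:

import torch.distributed as dist

def safe_teardown(is_distributed: bool) -> None:
    # Only touch the default process group when one was actually created.
    if is_distributed and dist.is_initialized():
        dist.barrier()                # let every rank reach teardown first
        dist.destroy_process_group()  # then release the default group
    # single-process run: nothing to destroy, return silently
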