[v1] add sft (#9752)
@@ -174,7 +174,7 @@ class DistributedInterface:
         """Get device mesh for specified dimension."""
         if dim is None:
             raise ValueError("dim must be specified.")
-        elif self.model_device_mesh is None:
+        elif not self._is_distributed:
             return None
         elif dim in self.strategy.data_mesh_dim_names:
             return self.data_device_mesh[dim.value]
@@ -183,14 +183,14 @@ class DistributedInterface:
 
     def get_group(self, dim: Dim | None = None) -> Optional[ProcessGroup]:
         """Get process group for specified dimension."""
-        if self.model_device_mesh is None or dim is None:
+        if not self._is_distributed or dim is None:
             return None
         else:
             return self.get_device_mesh(dim).get_group()
 
     def get_rank(self, dim: Dim | None = None) -> int:
         """Get parallel rank for specified dimension."""
-        if self.model_device_mesh is None:
+        if not self._is_distributed:
             return 0
         elif dim is None:
             return self._rank
@@ -199,7 +199,7 @@ class DistributedInterface:
 
     def get_world_size(self, dim: Dim | None = None) -> int:
         """Get parallel size for specified dimension."""
-        if self.model_device_mesh is None:
+        if not self._is_distributed:
             return 1
         elif dim is None:
             return self._world_size
@@ -216,7 +216,7 @@ class DistributedInterface:
 
     def all_gather(self, data: TensorLike, dim: Dim | None = Dim.DP) -> TensorLike:
         """Gather tensor across specified parallel group."""
-        if self.model_device_mesh is not None:
+        if self._is_distributed:
             return helper.operate_tensorlike(helper.all_gather, data, group=self.get_group(dim))
         else:
             return data
@@ -225,29 +225,32 @@ class DistributedInterface:
         self, data: TensorLike, op: helper.ReduceOp = helper.ReduceOp.MEAN, dim: Dim | None = Dim.DP
     ) -> TensorLike:
         """Reduce tensor across specified parallel group."""
-        if self.model_device_mesh is not None:
+        if self._is_distributed:
             return helper.operate_tensorlike(helper.all_reduce, data, op=op, group=self.get_group(dim))
         else:
             return data
 
     def broadcast(self, data: TensorLike, src: int = 0, dim: Dim | None = Dim.DP) -> TensorLike:
         """Broadcast tensor across specified parallel group."""
-        if self.model_device_mesh is not None:
+        if self._is_distributed:
             return helper.operate_tensorlike(helper.broadcast, data, src=src, group=self.get_group(dim))
         else:
             return data
 
     def sync(self) -> None:
         """Synchronize all processes."""
-        helper.synchronize()
+        if self._is_distributed:
+            helper.synchronize()
 
     def barrier(self) -> None:
         """Barrier all processes."""
-        barrier()
+        if self._is_distributed:
+            barrier()
 
     def destroy(self) -> None:
         """Destroy all processes."""
-        destroy_process_group()
+        if self._is_distributed:
+            destroy_process_group()
 
 
 if __name__ == "__main__":
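
The common thread in this diff is that every distributed helper is now gated on a single _is_distributed flag instead of probing model_device_mesh, so the class falls back to no-op or identity behavior in single-process runs. Below is a minimal, self-contained sketch of that guard pattern written against plain torch.distributed; MiniInterface and its fields are illustrative stand-ins, not the actual LLaMA-Factory DistributedInterface.

import torch
import torch.distributed as dist


class MiniInterface:
    """Illustrative stand-in for the guard pattern used in this diff."""

    def __init__(self) -> None:
        # True only when a process group has been initialized (e.g. via torchrun).
        self._is_distributed = dist.is_available() and dist.is_initialized()

    def get_rank(self) -> int:
        # Single-process fallback: rank 0.
        return dist.get_rank() if self._is_distributed else 0

    def get_world_size(self) -> int:
        # Single-process fallback: world size 1.
        return dist.get_world_size() if self._is_distributed else 1

    def all_reduce(self, tensor: torch.Tensor) -> torch.Tensor:
        # Identity when not distributed; collective sum otherwise.
        if self._is_distributed:
            dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
        return tensor

    def barrier(self) -> None:
        # No-op in a single-process run instead of raising without a process group.
        if self._is_distributed:
            dist.barrier()


if __name__ == "__main__":
    iface = MiniInterface()
    print(iface.get_rank(), iface.get_world_size())  # prints "0 1" when run without torchrun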