mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2026-05-05 15:48:54 +08:00
[v1] support training with fsdp2 (#9773)
Co-authored-by: frozenleaves <frozen@Mac.local> Co-authored-by: Yaowei Zheng <hiyouga@buaa.edu.cn>
This commit is contained in:
@@ -180,6 +180,16 @@ def operate_tensorlike(fn: Callable[[...], Tensor], data: TensorLike, **kwargs)
|
||||
return result.tolist()
|
||||
|
||||
|
||||
def get_process_group_backend() -> str:
    """Return the backend name to use when initializing the process group.

    The choice follows the current accelerator type: ``"hccl"`` for NPU,
    ``"nccl"`` for CUDA, and ``"gloo"`` for every other device type.
    """
    # NOTE(review): assumes get_current_accelerator() is side-effect free, so
    # reading its type once is equivalent to querying it per branch.
    device_type = get_current_accelerator().type

    if device_type == DeviceType.NPU:
        return "hccl"

    if device_type == DeviceType.CUDA:
        return "nccl"

    # Fallback for any accelerator type not matched above.
    return "gloo"
|
||||
|
||||
|
||||
def all_gather(tensor: Tensor, group: Optional[ProcessGroup] = None) -> Tensor:
|
||||
"""Gathers the tensor from all ranks and stacks them at the first dim."""
|
||||
world_size = get_world_size()
|
||||
|
||||
Reference in New Issue
Block a user