Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-12-27 09:10:35 +08:00)
[v1] model loader (#9613)
@@ -19,17 +19,13 @@ import os
 from contextlib import contextmanager
 from enum import Enum, unique
 from functools import lru_cache
-from typing import TYPE_CHECKING, Optional
+from typing import Optional

 import numpy as np
 import torch
 import torch.distributed as dist

-from ..utils.types import Tensor, TensorLike
-
-
-if TYPE_CHECKING:
-    from torch.distributed import ProcessGroup
+from ..utils.types import ProcessGroup, Tensor, TensorLike


 @unique
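This hunk replaces the TYPE_CHECKING-guarded import of ProcessGroup with a runtime re-export from the project's own types module, so the signatures below can annotate with ProcessGroup directly instead of the quoted "ProcessGroup" forward reference. A minimal sketch of what ..utils/types.py might now contain; everything here beyond the three re-exported names is an assumption, since that module is not part of this diff:

# utils/types.py (hypothetical sketch, not from this commit)
from typing import Union

import numpy as np
import torch
from torch.distributed import ProcessGroup  # runtime re-export; removes the need for a TYPE_CHECKING guard

Tensor = torch.Tensor
TensorLike = Union[torch.Tensor, np.ndarray, int, float]  # assumed union; the real definition is not shown

__all__ = ["ProcessGroup", "Tensor", "TensorLike"]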
@@ -107,7 +103,7 @@ def is_torch_xpu_available():
     return get_current_accelerator().type == DeviceType.XPU


-def all_gather(tensor: Tensor, group: Optional["ProcessGroup"] = None) -> Tensor:
+def all_gather(tensor: Tensor, group: Optional[ProcessGroup] = None) -> Tensor:
     """Gathers the tensor from all ranks and concats them along the first dim."""
     world_size = get_world_size()
     device = get_current_accelerator()
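The docstring pins down the semantics: every rank contributes its tensor, and the results are concatenated along dim 0 (the output_tensor.view(-1, ...) in the next hunk matches this). A usage sketch under assumed conditions; the launcher command, backend, and import path of all_gather are illustrative, not taken from the diff:

# Run with e.g.: torchrun --nproc_per_node=2 demo.py  (assumed launcher)
import torch
import torch.distributed as dist

# all_gather is the function changed above; its import path depends on the package layout.
dist.init_process_group(backend="gloo")  # "nccl" on GPU nodes
rank = dist.get_rank()

local = torch.full((2, 3), float(rank))  # each rank holds a (2, 3) shard
gathered = all_gather(local)             # shape (world_size * 2, 3), shards in rank order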
@@ -116,7 +112,7 @@ def all_gather(tensor: Tensor, group: Optional["ProcessGroup"] = None) -> Tensor
     return output_tensor.view(-1, *tensor.size()[1:])


-def all_reduce(data: TensorLike, op: ReduceOp = ReduceOp.MEAN, group: Optional["ProcessGroup"] = None) -> TensorLike:
+def all_reduce(data: TensorLike, op: ReduceOp = ReduceOp.MEAN, group: Optional[ProcessGroup] = None) -> TensorLike:
     """Performs all reduce in the given process group."""
     device = get_current_accelerator()
     is_ndarray = isinstance(data, np.ndarray)
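Likewise for all_reduce: the is_ndarray check suggests numpy inputs are moved into a tensor for the collective and converted back afterwards, and ReduceOp appears to be this module's own @unique enum (torch's ReduceOp has no MEAN member), so the default averages across ranks. A sketch under those assumptions:

import numpy as np
import torch.distributed as dist

dist.init_process_group(backend="gloo")  # same assumed two-process setup as above
rank = dist.get_rank()

# With the default op=ReduceOp.MEAN, each rank receives the average of all local values.
local_loss = np.array([0.5 * (rank + 1)], dtype=np.float32)
mean_loss = all_reduce(local_loss)  # np.ndarray in, np.ndarray out (assumed round-trip)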