[v1] model loader (#9613)

This commit is contained in:
Yaowei Zheng
2025-12-14 11:50:52 +08:00
committed by GitHub
parent fdd24276ed
commit aeda079014
27 changed files with 449 additions and 305 deletions

View File

@@ -19,17 +19,13 @@ import os
from contextlib import contextmanager
from enum import Enum, unique
from functools import lru_cache
from typing import TYPE_CHECKING, Optional
from typing import Optional
import numpy as np
import torch
import torch.distributed as dist
from ..utils.types import Tensor, TensorLike
if TYPE_CHECKING:
from torch.distributed import ProcessGroup
from ..utils.types import ProcessGroup, Tensor, TensorLike
@unique
@@ -107,7 +103,7 @@ def is_torch_xpu_available():
return get_current_accelerator().type == DeviceType.XPU
def all_gather(tensor: Tensor, group: Optional["ProcessGroup"] = None) -> Tensor:
def all_gather(tensor: Tensor, group: Optional[ProcessGroup] = None) -> Tensor:
"""Gathers the tensor from all ranks and concats them along the first dim."""
world_size = get_world_size()
device = get_current_accelerator()
@@ -116,7 +112,7 @@ def all_gather(tensor: Tensor, group: Optional["ProcessGroup"] = None) -> Tensor
return output_tensor.view(-1, *tensor.size()[1:])
def all_reduce(data: TensorLike, op: ReduceOp = ReduceOp.MEAN, group: Optional["ProcessGroup"] = None) -> TensorLike:
def all_reduce(data: TensorLike, op: ReduceOp = ReduceOp.MEAN, group: Optional[ProcessGroup] = None) -> TensorLike:
"""Performs all reduce in the given process group."""
device = get_current_accelerator()
is_ndarray = isinstance(data, np.ndarray)