[v1] model loader (#9613)

Authored by Yaowei Zheng on 2025-12-14 11:50:52 +08:00, committed by GitHub
parent fdd24276ed
commit aeda079014
27 changed files with 449 additions and 305 deletions


@@ -16,12 +16,14 @@
 # limitations under the License.
 from dataclasses import dataclass
 from datetime import timedelta
 from enum import Enum
-from typing import TYPE_CHECKING, Any, Optional
+from typing import Any, Optional

 from torch.distributed import init_process_group
 from torch.distributed.device_mesh import DeviceMesh, init_device_mesh

-from ..utils.types import Tensor, TensorLike
+from ..utils.types import DistributedConfig, ProcessGroup, Tensor, TensorLike
 from .helper import (
     ReduceOp,
     all_gather,
@@ -35,10 +37,6 @@ from .helper import (
 )

-if TYPE_CHECKING:
-    from torch.distributed import ProcessGroup
-

 class Dim(str, Enum):
     """Dimension names."""
@@ -130,21 +128,33 @@ class DistributedInterface:
         return cls._instance

-    def __init__(self, strategy: DistributedStrategy) -> None:
+    def __init__(self, config: Optional[DistributedConfig] = None) -> None:
         if self._initialized:
             return

-        self.strategy = strategy
+        if config is None:
+            self.strategy = DistributedStrategy()
+            timeout = 18000
+        else:
+            self.strategy = DistributedStrategy(
+                mp_replicate_size=config.get("mp_replicate_size", 1),
+                mp_shard_size=config.get("mp_shard_size", None),
+                dp_size=config.get("dp_size", None),
+                cp_size=config.get("cp_size", 1),
+            )
+            timeout = config.get("timeout", 18000)

         if self._is_distributed:
             init_process_group(timeout=timedelta(seconds=timeout))
             self.model_device_mesh = init_device_mesh(
                 device_type=self.current_accelerator.type,
-                mesh_shape=strategy.model_mesh_shape,
-                mesh_dim_names=strategy.model_mesh_dim_names,
+                mesh_shape=self.strategy.model_mesh_shape,
+                mesh_dim_names=self.strategy.model_mesh_dim_names,
             )
             self.data_device_mesh = init_device_mesh(
                 device_type=self.current_accelerator.type,
-                mesh_shape=strategy.data_mesh_shape,
-                mesh_dim_names=strategy.data_mesh_dim_names,
+                mesh_shape=self.strategy.data_mesh_shape,
+                mesh_dim_names=self.strategy.data_mesh_dim_names,
             )
         else:
             self.model_device_mesh = None
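
The constructor now accepts an optional config mapping instead of a prebuilt `DistributedStrategy`. A hedged usage sketch (the key names come from the diff above; the concrete values are illustrative only):

```python
# Illustrative only: construct the interface from a config dict; unset keys
# fall back to the defaults shown in __init__ above.
config: DistributedConfig = {
    "mp_shard_size": 8,  # e.g. shard model parameters across 8 ranks (made-up value)
    "timeout": 3600,     # overrides the 18000 s init_process_group default
}
interface = DistributedInterface(config)

# With no config, a default DistributedStrategy is built; since the class is
# a singleton (see cls._instance), repeated calls return the same object.
same = DistributedInterface()
assert same is interface
```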
@@ -172,7 +182,7 @@ class DistributedInterface:
         return cls.model_device_mesh[dim.value]

     @classmethod
-    def get_group(cls, dim: Optional[Dim] = None) -> Optional["ProcessGroup"]:
+    def get_group(cls, dim: Optional[Dim] = None) -> Optional[ProcessGroup]:
         """Get process group for specified dimension."""
         if cls.model_device_mesh is None or dim is None:
             return None
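
For reference, `get_group` sits on top of the standard `torch.distributed.device_mesh` API. A standalone sketch of that pattern (run under e.g. `torchrun`; the mesh dimension names here are illustrative, not necessarily those produced by `DistributedStrategy`):

```python
# Standalone sketch of the device-mesh pattern DistributedInterface wraps.
# Launch with e.g. `torchrun --nproc_per_node=4 demo.py`.
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh

dist.init_process_group()  # backend and ranks come from the torchrun env

mesh = init_device_mesh(
    device_type="cuda",
    mesh_shape=(2, 2),                  # 4 ranks arranged as a 2x2 grid
    mesh_dim_names=("dp", "mp_shard"),  # illustrative dimension names
)

sub_mesh = mesh["dp"]         # 1D sub-mesh, like model_device_mesh[dim.value] above
group = mesh.get_group("dp")  # a ProcessGroup, matching get_group's return type
```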