Mirror of https://github.com/hiyouga/LLaMA-Factory.git
[v1] model loader (#9613)
@@ -16,12 +16,14 @@
 # limitations under the License.

 from dataclasses import dataclass
 from datetime import timedelta
 from enum import Enum
-from typing import TYPE_CHECKING, Any, Optional
+from typing import Any, Optional

 from torch.distributed import init_process_group
 from torch.distributed.device_mesh import DeviceMesh, init_device_mesh

-from ..utils.types import Tensor, TensorLike
+from ..utils.types import DistributedConfig, ProcessGroup, Tensor, TensorLike
 from .helper import (
     ReduceOp,
     all_gather,
@@ -35,10 +37,6 @@ from .helper import (
 )


-if TYPE_CHECKING:
-    from torch.distributed import ProcessGroup
-
-
 class Dim(str, Enum):
     """Dimension names."""

@@ -130,21 +128,33 @@ class DistributedInterface:

         return cls._instance

-    def __init__(self, strategy: DistributedStrategy) -> None:
+    def __init__(self, config: Optional[DistributedConfig] = None) -> None:
         if self._initialized:
             return

-        self.strategy = strategy
+        if config is None:
+            self.strategy = DistributedStrategy()
+            timeout = 18000
+        else:
+            self.strategy = DistributedStrategy(
+                mp_replicate_size=config.get("mp_replicate_size", 1),
+                mp_shard_size=config.get("mp_shard_size", None),
+                dp_size=config.get("dp_size", None),
+                cp_size=config.get("cp_size", 1),
+            )
+            timeout = config.get("timeout", 18000)

         if self._is_distributed:
             init_process_group(timeout=timedelta(seconds=timeout))
             self.model_device_mesh = init_device_mesh(
                 device_type=self.current_accelerator.type,
-                mesh_shape=strategy.model_mesh_shape,
-                mesh_dim_names=strategy.model_mesh_dim_names,
+                mesh_shape=self.strategy.model_mesh_shape,
+                mesh_dim_names=self.strategy.model_mesh_dim_names,
             )
             self.data_device_mesh = init_device_mesh(
                 device_type=self.current_accelerator.type,
-                mesh_shape=strategy.data_mesh_shape,
-                mesh_dim_names=strategy.data_mesh_dim_names,
+                mesh_shape=self.strategy.data_mesh_shape,
+                mesh_dim_names=self.strategy.data_mesh_dim_names,
             )
         else:
             self.model_device_mesh = None
@@ -172,7 +182,7 @@ class DistributedInterface:
         return cls.model_device_mesh[dim.value]

     @classmethod
-    def get_group(cls, dim: Optional[Dim] = None) -> Optional["ProcessGroup"]:
+    def get_group(cls, dim: Optional[Dim] = None) -> Optional[ProcessGroup]:
         """Get process group for specified dimension."""
         if cls.model_device_mesh is None or dim is None:
             return None
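
For context, here is a minimal usage sketch of the initialization flow this diff introduces. It is not part of the commit: the `torchrun` launch, the CUDA device type, the 4x2 mesh shape, the dimension names, and the `main` wrapper are illustrative assumptions; only the config keys (`mp_replicate_size`, `mp_shard_size`, `dp_size`, `cp_size`, `timeout`) and the order of calls (process group first, then `init_device_mesh`, then name-based group lookup) mirror the patched `__init__` and `get_group` above.

```python
# Hypothetical driver, launched with: torchrun --nproc_per_node=8 demo.py
from datetime import timedelta

import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh


def main() -> None:
    # Config dict carrying the same keys that the new __init__ reads.
    config = {
        "mp_replicate_size": 1,
        "mp_shard_size": 2,   # assumed: 2-way sharded model parallelism
        "dp_size": 4,         # assumed: 4-way data parallelism (4 * 2 = 8 ranks)
        "cp_size": 1,
        "timeout": 18000,
    }

    # Same ordering as the patched __init__: initialize the default process
    # group with the configured timeout, then carve it into a named device mesh.
    dist.init_process_group(timeout=timedelta(seconds=config["timeout"]))
    model_mesh = init_device_mesh(
        device_type="cuda",                                        # assumed accelerator
        mesh_shape=(config["dp_size"], config["mp_shard_size"]),   # assumed layout
        mesh_dim_names=("dp", "mp_shard"),                         # assumed dim names
    )

    # Analogous to DistributedInterface.get_group(dim): index the mesh by
    # dimension name and ask the resulting 1-D sub-mesh for its process group.
    dp_group = model_mesh["dp"].get_group()
    print(f"rank {dist.get_rank()} -> dp group of size {dist.get_world_size(dp_group)}")

    dist.destroy_process_group()


if __name__ == "__main__":
    main()
```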