[v1] model loader (#9613)
@@ -19,17 +19,13 @@ import os
 from contextlib import contextmanager
 from enum import Enum, unique
 from functools import lru_cache
-from typing import TYPE_CHECKING, Optional
+from typing import Optional

 import numpy as np
 import torch
 import torch.distributed as dist

-from ..utils.types import Tensor, TensorLike
-
-
-if TYPE_CHECKING:
-    from torch.distributed import ProcessGroup
+from ..utils.types import ProcessGroup, Tensor, TensorLike


 @unique
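Note on the import change above: the quoted forward reference and the `if TYPE_CHECKING:` guard become unnecessary once `..utils.types` re-exports `ProcessGroup` at runtime. The contents of `utils/types.py` are not part of this diff; a minimal sketch of such a re-export, offered purely as an assumption, could look like:

# utils/types.py -- hypothetical sketch, not shown in this diff
from typing import Union

import numpy as np
import torch
from torch.distributed import ProcessGroup  # real runtime import, safe to use in annotations

Tensor = torch.Tensor
TensorLike = Union[torch.Tensor, np.ndarray, int, float]

__all__ = ["ProcessGroup", "Tensor", "TensorLike"]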
@@ -107,7 +103,7 @@ def is_torch_xpu_available():
     return get_current_accelerator().type == DeviceType.XPU


-def all_gather(tensor: Tensor, group: Optional["ProcessGroup"] = None) -> Tensor:
+def all_gather(tensor: Tensor, group: Optional[ProcessGroup] = None) -> Tensor:
     """Gathers the tensor from all ranks and concats them along the first dim."""
     world_size = get_world_size()
     device = get_current_accelerator()
@@ -116,7 +112,7 @@ def all_gather(tensor: Tensor, group: Optional["ProcessGroup"] = None) -> Tensor
     return output_tensor.view(-1, *tensor.size()[1:])


-def all_reduce(data: TensorLike, op: ReduceOp = ReduceOp.MEAN, group: Optional["ProcessGroup"] = None) -> TensorLike:
+def all_reduce(data: TensorLike, op: ReduceOp = ReduceOp.MEAN, group: Optional[ProcessGroup] = None) -> TensorLike:
     """Performs all reduce in the given process group."""
     device = get_current_accelerator()
     is_ndarray = isinstance(data, np.ndarray)
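The two collectives keep their call pattern; only the annotations change. A usage sketch, assuming the helpers from this file are importable (the flat import path below is illustrative, not the repository layout):

# Usage sketch -- import path is illustrative only.
import torch

from helper import ReduceOp, all_gather, all_reduce

def summarize_loss(local_loss: torch.Tensor):
    # all_gather concatenates per-rank tensors along the first dim,
    # so a (B, ...) tensor becomes (world_size * B, ...).
    gathered = all_gather(local_loss)
    # all_reduce averages a scalar or array across ranks (ReduceOp.MEAN is the default).
    mean_loss = all_reduce(local_loss.mean(), op=ReduceOp.MEAN)
    return gathered, mean_loss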
@@ -16,12 +16,14 @@
 # limitations under the License.

 from dataclasses import dataclass
 from datetime import timedelta
 from enum import Enum
-from typing import TYPE_CHECKING, Any, Optional
+from typing import Any, Optional

 from torch.distributed import init_process_group
 from torch.distributed.device_mesh import DeviceMesh, init_device_mesh

-from ..utils.types import Tensor, TensorLike
+from ..utils.types import DistributedConfig, ProcessGroup, Tensor, TensorLike
 from .helper import (
     ReduceOp,
     all_gather,
@@ -35,10 +37,6 @@ from .helper import (
 )


-if TYPE_CHECKING:
-    from torch.distributed import ProcessGroup
-
-
 class Dim(str, Enum):
     """Dimension names."""

@@ -130,21 +128,33 @@ class DistributedInterface:

         return cls._instance

-    def __init__(self, strategy: DistributedStrategy) -> None:
+    def __init__(self, config: Optional[DistributedConfig] = None) -> None:
         if self._initialized:
             return

-        self.strategy = strategy
+        if config is None:
+            self.strategy = DistributedStrategy()
+            timeout = 18000
+        else:
+            self.strategy = DistributedStrategy(
+                mp_replicate_size=config.get("mp_replicate_size", 1),
+                mp_shard_size=config.get("mp_shard_size", None),
+                dp_size=config.get("dp_size", None),
+                cp_size=config.get("cp_size", 1),
+            )
+            timeout = config.get("timeout", 18000)

         if self._is_distributed:
             init_process_group(timeout=timedelta(seconds=timeout))
             self.model_device_mesh = init_device_mesh(
                 device_type=self.current_accelerator.type,
-                mesh_shape=strategy.model_mesh_shape,
-                mesh_dim_names=strategy.model_mesh_dim_names,
+                mesh_shape=self.strategy.model_mesh_shape,
+                mesh_dim_names=self.strategy.model_mesh_dim_names,
             )
             self.data_device_mesh = init_device_mesh(
                 device_type=self.current_accelerator.type,
-                mesh_shape=strategy.data_mesh_shape,
-                mesh_dim_names=strategy.data_mesh_dim_names,
+                mesh_shape=self.strategy.data_mesh_shape,
+                mesh_dim_names=self.strategy.data_mesh_dim_names,
             )
         else:
             self.model_device_mesh = None
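With the new signature, callers pass an optional DistributedConfig mapping instead of a prebuilt DistributedStrategy; mesh sizes and the process-group timeout are read with config.get using the defaults shown above. A caller-side sketch, with values chosen purely for illustration:

# Caller-side sketch -- the keys mirror the diff above, the values are illustrative.
config = {
    "mp_replicate_size": 1,
    "mp_shard_size": 8,    # shard model parameters across 8 ranks
    "dp_size": None,       # let the strategy derive the data-parallel size
    "cp_size": 1,
    "timeout": 18000,      # seconds; wrapped in timedelta for init_process_group
}

dist_interface = DistributedInterface(config)  # DistributedInterface() falls back to the defaults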
@@ -172,7 +182,7 @@ class DistributedInterface:
         return cls.model_device_mesh[dim.value]

     @classmethod
-    def get_group(cls, dim: Optional[Dim] = None) -> Optional["ProcessGroup"]:
+    def get_group(cls, dim: Optional[Dim] = None) -> Optional[ProcessGroup]:
         """Get process group for specified dimension."""
         if cls.model_device_mesh is None or dim is None:
             return None
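Since get_group now returns a plain Optional[ProcessGroup], its result can be handed directly to the collectives in helper.py. A sketch under stated assumptions: the Dim member name and the import locations are not spelled out in this diff.

# Illustrative only -- Dim.DP is an assumed member of the Dim enum shown above.
import torch

dp_group = DistributedInterface.get_group(Dim.DP)  # None outside distributed runs
mean_loss = all_reduce(torch.tensor(0.42), op=ReduceOp.MEAN, group=dp_group)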