mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2026-03-03 18:25:59 +08:00
[v1] add cli sampler (#9721)
This commit is contained in:
@@ -123,12 +123,13 @@ class DistributedInterface:
|
||||
if self._initialized:
|
||||
return
|
||||
|
||||
helper.set_device_index()
|
||||
self._is_distributed = helper.is_distributed()
|
||||
self._rank = helper.get_rank()
|
||||
self._world_size = helper.get_world_size()
|
||||
self._local_rank = helper.get_local_rank()
|
||||
self._local_world_size = helper.get_local_world_size()
|
||||
self.current_accelerator = helper.get_current_accelerator()
|
||||
self.current_device = helper.get_current_device()
|
||||
self.device_count = helper.get_device_count()
|
||||
|
||||
if config is None:
|
||||
@@ -144,15 +145,14 @@ class DistributedInterface:
|
||||
timeout = config.get("timeout", 18000)
|
||||
|
||||
if self._is_distributed:
|
||||
helper.set_device()
|
||||
init_process_group(timeout=timedelta(seconds=timeout))
|
||||
self.model_device_mesh = init_device_mesh(
|
||||
device_type=self.current_accelerator.type,
|
||||
device_type=self.current_device.type,
|
||||
mesh_shape=self.strategy.model_mesh_shape,
|
||||
mesh_dim_names=self.strategy.model_mesh_dim_names,
|
||||
)
|
||||
self.data_device_mesh = init_device_mesh(
|
||||
device_type=self.current_accelerator.type,
|
||||
device_type=self.current_device.type,
|
||||
mesh_shape=self.strategy.data_mesh_shape,
|
||||
mesh_dim_names=self.strategy.data_mesh_dim_names,
|
||||
)
|
||||
@@ -161,12 +161,12 @@ class DistributedInterface:
|
||||
self.data_device_mesh = None
|
||||
|
||||
self._initialized = True
|
||||
logger.info_rank0(f"DistributedInterface initialized with strategy={self.strategy}.")
|
||||
logger.info_rank0(f"DistributedInterface initialized: {self}.")
|
||||
|
||||
def __str__(self) -> str:
|
||||
return (
|
||||
f"DistributedInterface(strategy={self.strategy}), is_distributed={self._is_distributed}, "
|
||||
f"current_accelerator={self.current_accelerator}, rank={self._rank}, world_size={self._world_size}, "
|
||||
f"current_device={self.current_device}, rank={self._rank}, world_size={self._world_size}, "
|
||||
f"model_device_mesh={self.model_device_mesh}, data_device_mesh={self.data_device_mesh}"
|
||||
)
|
||||
|
||||
@@ -251,4 +251,7 @@ class DistributedInterface:
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
print(DistributedInterface(DistributedStrategy()))
|
||||
"""
|
||||
python -m llamafactory.v1.accelerator.interface
|
||||
"""
|
||||
print(DistributedInterface())
|
||||
|
||||
Reference in New Issue
Block a user