[v1] add cli sampler (#9721)

This commit is contained in:
Yaowei Zheng
2026-01-06 23:31:27 +08:00
committed by GitHub
parent e944dc442c
commit ea0b4e2466
45 changed files with 1091 additions and 505 deletions

View File

@@ -123,12 +123,13 @@ class DistributedInterface:
if self._initialized:
return
helper.set_device_index()
self._is_distributed = helper.is_distributed()
self._rank = helper.get_rank()
self._world_size = helper.get_world_size()
self._local_rank = helper.get_local_rank()
self._local_world_size = helper.get_local_world_size()
self.current_accelerator = helper.get_current_accelerator()
self.current_device = helper.get_current_device()
self.device_count = helper.get_device_count()
if config is None:
@@ -144,15 +145,14 @@ class DistributedInterface:
timeout = config.get("timeout", 18000)
if self._is_distributed:
helper.set_device()
init_process_group(timeout=timedelta(seconds=timeout))
self.model_device_mesh = init_device_mesh(
device_type=self.current_accelerator.type,
device_type=self.current_device.type,
mesh_shape=self.strategy.model_mesh_shape,
mesh_dim_names=self.strategy.model_mesh_dim_names,
)
self.data_device_mesh = init_device_mesh(
device_type=self.current_accelerator.type,
device_type=self.current_device.type,
mesh_shape=self.strategy.data_mesh_shape,
mesh_dim_names=self.strategy.data_mesh_dim_names,
)
@@ -161,12 +161,12 @@ class DistributedInterface:
self.data_device_mesh = None
self._initialized = True
logger.info_rank0(f"DistributedInterface initialized with strategy={self.strategy}.")
logger.info_rank0(f"DistributedInterface initialized: {self}.")
# Build the one-line, human-readable summary of this interface: strategy,
# distributed flag, device, rank, world size and both device meshes.
# Returns: a single str made of the adjacent f-string pieces below, which
# Python joins into one literal.
def __str__(self) -> str:
return (
# NOTE(review): the ")" right after the strategy field closes the
# "DistributedInterface(" group early, so every later field is rendered
# outside the parentheses -- confirm whether that is intended.
f"DistributedInterface(strategy={self.strategy}), is_distributed={self._is_distributed}, "
# NOTE(review): the next two pieces differ only in current_accelerator vs
# current_device and both repeat rank/world_size; presumably they are the
# removed and added sides of this diff hunk shown together -- confirm
# against the real file before relying on either.
f"current_accelerator={self.current_accelerator}, rank={self._rank}, world_size={self._world_size}, "
f"current_device={self.current_device}, rank={self._rank}, world_size={self._world_size}, "
f"model_device_mesh={self.model_device_mesh}, data_device_mesh={self.data_device_mesh}"
)
@@ -251,4 +251,7 @@ class DistributedInterface:
# Script entry point: constructs a DistributedInterface and prints it, which
# renders the object through its string form.
if __name__ == "__main__":
# NOTE(review): two print calls are visible here, one passing
# DistributedStrategy() and one passing no argument; presumably the first is
# the pre-change line of this diff hunk and the last its replacement --
# confirm against the real file.
print(DistributedInterface(DistributedStrategy()))
# The bare string below is a no-op expression statement that records the
# command line used to run this module directly.
"""
python -m llamafactory.v1.accelerator.interface
"""
print(DistributedInterface())