mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2026-05-05 15:48:54 +08:00
[v1] add cli sampler (#9721)
This commit is contained in:
@@ -119,9 +119,19 @@ def synchronize() -> None:
|
||||
|
||||
|
||||
@requires_accelerator
def set_device() -> None:
    """Bind the current accelerator to this process's local rank."""
    # Resolve the rank first so the index passed to torch is explicit.
    local_rank = get_local_rank()
    torch.accelerator.set_device_index(local_rank)
|
||||
def set_device_index() -> None:
    """Point the accelerator's device index at this process's local rank.

    A no-op when the detected accelerator is the CPU, since the CPU has
    no per-rank device index to select.
    """
    if get_current_accelerator().type == DeviceType.CPU:
        return

    torch.accelerator.set_device_index(get_local_rank())
|
||||
|
||||
|
||||
@requires_accelerator
def get_current_device() -> torch.device:
    """Return a ``torch.device`` describing the active accelerator.

    Falls back to a plain CPU device when no non-CPU accelerator is
    present; otherwise pairs the accelerator type with the currently
    selected device index.
    """
    if get_current_accelerator().type != DeviceType.CPU:
        return torch.device(
            type=get_current_accelerator().type,
            index=torch.accelerator.current_device_index(),
        )

    return torch.device(DeviceType.CPU.value)
|
||||
|
||||
|
||||
def is_torch_cuda_available():
|
||||
|
||||
Reference in New Issue
Block a user