improve get_current_device

Former-commit-id: 40dfcbc3d4571ce022b6aa39db581c8b88a75b8d
This commit is contained in:
billvsme 2023-11-30 22:40:35 +08:00
parent 3d291a82d3
commit e400f2e8ad

View File

@ -69,11 +69,11 @@ def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:
def get_current_device() -> str: def get_current_device() -> str:
import accelerate import accelerate
dummy_accelerator = accelerate.Accelerator() local_rank = int(os.environ.get('LOCAL_RANK', '0'))
if accelerate.utils.is_xpu_available(): if accelerate.utils.is_xpu_available():
return "xpu:{}".format(dummy_accelerator.local_process_index) return "xpu:{}".format(local_rank)
else: else:
return dummy_accelerator.local_process_index if torch.cuda.is_available() else "cpu" return local_rank if torch.cuda.is_available() else "cpu"
def get_logits_processor() -> "LogitsProcessorList": def get_logits_processor() -> "LogitsProcessorList":