This commit is contained in:
hiyouga
2023-11-19 14:15:47 +08:00
parent ff6056405d
commit 1740131d63
8 changed files with 35 additions and 31 deletions

View File

@@ -66,8 +66,7 @@ def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:
def get_current_device() -> str:
import accelerate
from accelerate import Accelerator
dummy_accelerator = Accelerator()
dummy_accelerator = accelerate.Accelerator()
if accelerate.utils.is_xpu_available():
return "xpu:{}".format(dummy_accelerator.local_process_index)
else: