From 7b45f5068ff9ad49f036e4b7a43e311007c91b85 Mon Sep 17 00:00:00 2001 From: billvsme <994171686@qq.com> Date: Thu, 30 Nov 2023 22:40:35 +0800 Subject: [PATCH] improve get_current_device Former-commit-id: 2b07815e7fc8dc6ad0a7e9eccdd6681fbab35f3c --- src/llmtuner/extras/misc.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/llmtuner/extras/misc.py b/src/llmtuner/extras/misc.py index 672110cf..2e8d16a8 100644 --- a/src/llmtuner/extras/misc.py +++ b/src/llmtuner/extras/misc.py @@ -69,11 +69,11 @@ def count_parameters(model: torch.nn.Module) -> Tuple[int, int]: def get_current_device() -> str: import accelerate - dummy_accelerator = accelerate.Accelerator() + local_rank = int(os.environ.get('LOCAL_RANK', '0')) if accelerate.utils.is_xpu_available(): - return "xpu:{}".format(dummy_accelerator.local_process_index) + return "xpu:{}".format(local_rank) else: - return dummy_accelerator.local_process_index if torch.cuda.is_available() else "cpu" + return local_rank if torch.cuda.is_available() else "cpu" def get_logits_processor() -> "LogitsProcessorList":