From a44ba7a2b82415272566e0460da1535115e5f428 Mon Sep 17 00:00:00 2001
From: hiyouga
Date: Fri, 1 Dec 2023 15:58:50 +0800
Subject: [PATCH] tiny fix

Former-commit-id: e597d3c084c8700e247bad6e26d2ee40fc3c316b
---
 src/llmtuner/extras/misc.py | 7 ++++---
 src/llmtuner/webui/utils.py | 4 ++--
 2 files changed, 6 insertions(+), 5 deletions(-)

diff --git a/src/llmtuner/extras/misc.py b/src/llmtuner/extras/misc.py
index 2e8d16a8..4f754c5c 100644
--- a/src/llmtuner/extras/misc.py
+++ b/src/llmtuner/extras/misc.py
@@ -69,11 +69,12 @@ def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:
 
 def get_current_device() -> str:
     import accelerate
-    local_rank = int(os.environ.get('LOCAL_RANK', '0'))
     if accelerate.utils.is_xpu_available():
-        return "xpu:{}".format(local_rank)
+        return "xpu:{}".format(os.environ.get("LOCAL_RANK", "0"))
+    elif accelerate.utils.is_npu_available() or torch.cuda.is_available():
+        return os.environ.get("LOCAL_RANK", "0")
     else:
-        return local_rank if torch.cuda.is_available() else "cpu"
+        return "cpu"
 
 
 def get_logits_processor() -> "LogitsProcessorList":
diff --git a/src/llmtuner/webui/utils.py b/src/llmtuner/webui/utils.py
index fc2c5c2c..34dd1804 100644
--- a/src/llmtuner/webui/utils.py
+++ b/src/llmtuner/webui/utils.py
@@ -44,8 +44,8 @@ def can_quantize(finetuning_type: str) -> Dict[str, Any]:
 def gen_cmd(args: Dict[str, Any]) -> str:
     args.pop("disable_tqdm", None)
     args["plot_loss"] = args.get("do_train", None)
-    cuda_visible_devices = os.environ.get('CUDA_VISIBLE_DEVICES') or "0"
-    cmd_lines = [f"CUDA_VISIBLE_DEVICES={cuda_visible_devices} python src/train_bash.py "]
+    current_devices = os.environ.get("CUDA_VISIBLE_DEVICES", "0")
+    cmd_lines = ["CUDA_VISIBLE_DEVICES={} python src/train_bash.py ".format(current_devices)]
     for k, v in args.items():
         if v is not None and v != "":
             cmd_lines.append(" --{} {} ".format(k, str(v)))