diff --git a/src/llmtuner/extras/misc.py b/src/llmtuner/extras/misc.py
index 2e8d16a8..4f754c5c 100644
--- a/src/llmtuner/extras/misc.py
+++ b/src/llmtuner/extras/misc.py
@@ -69,11 +69,12 @@ def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:
 
 def get_current_device() -> str:
     import accelerate
-    local_rank = int(os.environ.get('LOCAL_RANK', '0'))
     if accelerate.utils.is_xpu_available():
-        return "xpu:{}".format(local_rank)
+        return "xpu:{}".format(os.environ.get("LOCAL_RANK", "0"))
+    elif accelerate.utils.is_npu_available() or torch.cuda.is_available():
+        return os.environ.get("LOCAL_RANK", "0")
     else:
-        return local_rank if torch.cuda.is_available() else "cpu"
+        return "cpu"
 
 
 def get_logits_processor() -> "LogitsProcessorList":
diff --git a/src/llmtuner/webui/utils.py b/src/llmtuner/webui/utils.py
index fc2c5c2c..34dd1804 100644
--- a/src/llmtuner/webui/utils.py
+++ b/src/llmtuner/webui/utils.py
@@ -44,8 +44,8 @@ def can_quantize(finetuning_type: str) -> Dict[str, Any]:
 def gen_cmd(args: Dict[str, Any]) -> str:
     args.pop("disable_tqdm", None)
     args["plot_loss"] = args.get("do_train", None)
-    cuda_visible_devices = os.environ.get('CUDA_VISIBLE_DEVICES') or "0"
-    cmd_lines = [f"CUDA_VISIBLE_DEVICES={cuda_visible_devices} python src/train_bash.py "]
+    current_devices = os.environ.get("CUDA_VISIBLE_DEVICES", "0")
+    cmd_lines = ["CUDA_VISIBLE_DEVICES={} python src/train_bash.py ".format(current_devices)]
     for k, v in args.items():
         if v is not None and v != "":
             cmd_lines.append(" --{} {} ".format(k, str(v)))