From 2279b1948ea749ed6c2ac92154dc8d62861c343b Mon Sep 17 00:00:00 2001
From: hiyouga
Date: Sun, 3 Dec 2023 11:33:12 +0800
Subject: [PATCH] fix #1707 #1710

Former-commit-id: 03d05991f81b51826d3a4d9da214504e19a301bd
---
 src/llmtuner/extras/misc.py  | 12 ------------
 src/llmtuner/model/loader.py |  7 ++++---
 2 files changed, 4 insertions(+), 15 deletions(-)

diff --git a/src/llmtuner/extras/misc.py b/src/llmtuner/extras/misc.py
index 10f5799e..2c659993 100644
--- a/src/llmtuner/extras/misc.py
+++ b/src/llmtuner/extras/misc.py
@@ -68,18 +68,6 @@ def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:
     return trainable_params, all_param
 
 
-def get_current_device() -> str:
-    import accelerate
-    if accelerate.utils.is_xpu_available():
-        return "xpu:{}".format(os.environ.get("LOCAL_RANK", "0"))
-    elif accelerate.utils.is_npu_available():
-        return "npu:{}".format(os.environ.get("LOCAL_RANK", "0"))
-    elif torch.cuda.is_available():
-        return "cuda:{}".format(os.environ.get("LOCAL_RANK", "0"))
-    else:
-        return "cpu"
-
-
 def get_logits_processor() -> "LogitsProcessorList":
     r"""
     Gets logits processor that removes NaN and Inf logits.
diff --git a/src/llmtuner/model/loader.py b/src/llmtuner/model/loader.py
index 530869d5..e5075e37 100644
--- a/src/llmtuner/model/loader.py
+++ b/src/llmtuner/model/loader.py
@@ -1,3 +1,4 @@
+import os
 import math
 import torch
 from types import MethodType
@@ -22,7 +23,7 @@ except ImportError: # https://github.com/huggingface/transformers/releases/tag/v
 from transformers.deepspeed import is_deepspeed_zero3_enabled
 
 from llmtuner.extras.logging import get_logger
-from llmtuner.extras.misc import count_parameters, get_current_device, infer_optim_dtype, try_download_model_from_ms
+from llmtuner.extras.misc import count_parameters, infer_optim_dtype, try_download_model_from_ms
 from llmtuner.extras.packages import is_flash_attn2_available
 from llmtuner.extras.patches import llama_patch as LlamaPatches
 from llmtuner.hparams import FinetuningArguments
@@ -150,7 +151,7 @@ def load_model_and_tokenizer(
     if getattr(config, "quantization_config", None):
         if model_args.quantization_bit is not None: # remove bnb quantization
             model_args.quantization_bit = None
-        config_kwargs["device_map"] = {"": get_current_device()}
+        config_kwargs["device_map"] = {"": int(os.environ.get("LOCAL_RANK", "0"))}
         quantization_config = getattr(config, "quantization_config", None)
         logger.info("Loading {}-bit quantized model.".format(quantization_config.get("bits", -1)))
 
@@ -172,7 +173,7 @@ def load_model_and_tokenizer(
                 bnb_4bit_quant_type=model_args.quantization_type
             )
 
-        config_kwargs["device_map"] = {"": get_current_device()}
+        config_kwargs["device_map"] = {"": int(os.environ.get("LOCAL_RANK", "0"))}
         logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit))
 
     # Load pre-trained models (without valuehead)
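
Note (not part of the patch): a minimal sketch of the device-map pattern this change adopts, i.e. placing the whole quantized model on the device indexed by LOCAL_RANK instead of calling the removed get_current_device() helper. The model id and 4-bit settings below are illustrative assumptions, not values taken from this repository.

import os

import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# Each distributed worker reads its local device index from LOCAL_RANK
# (falling back to 0 in single-process runs), mirroring the patched lines.
local_rank = int(os.environ.get("LOCAL_RANK", "0"))

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",  # hypothetical model id, for illustration only
    quantization_config=BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
    ),
    # {"": device} maps every module to that single device, so each rank
    # holds a full copy of the model on its own GPU.
    device_map={"": local_rank},
)

An integer device value in device_map is interpreted as a GPU index, which is what the patched lines rely on.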