diff --git a/src/llmtuner/extras/misc.py b/src/llmtuner/extras/misc.py index 10f5799e..2c659993 100644 --- a/src/llmtuner/extras/misc.py +++ b/src/llmtuner/extras/misc.py @@ -68,18 +68,6 @@ def count_parameters(model: torch.nn.Module) -> Tuple[int, int]: return trainable_params, all_param -def get_current_device() -> str: - import accelerate - if accelerate.utils.is_xpu_available(): - return "xpu:{}".format(os.environ.get("LOCAL_RANK", "0")) - elif accelerate.utils.is_npu_available(): - return "npu:{}".format(os.environ.get("LOCAL_RANK", "0")) - elif torch.cuda.is_available(): - return "cuda:{}".format(os.environ.get("LOCAL_RANK", "0")) - else: - return "cpu" - - def get_logits_processor() -> "LogitsProcessorList": r""" Gets logits processor that removes NaN and Inf logits. diff --git a/src/llmtuner/model/loader.py b/src/llmtuner/model/loader.py index 530869d5..e5075e37 100644 --- a/src/llmtuner/model/loader.py +++ b/src/llmtuner/model/loader.py @@ -1,3 +1,4 @@ +import os import math import torch from types import MethodType @@ -22,7 +23,7 @@ except ImportError: # https://github.com/huggingface/transformers/releases/tag/v from transformers.deepspeed import is_deepspeed_zero3_enabled from llmtuner.extras.logging import get_logger -from llmtuner.extras.misc import count_parameters, get_current_device, infer_optim_dtype, try_download_model_from_ms +from llmtuner.extras.misc import count_parameters, infer_optim_dtype, try_download_model_from_ms from llmtuner.extras.packages import is_flash_attn2_available from llmtuner.extras.patches import llama_patch as LlamaPatches from llmtuner.hparams import FinetuningArguments @@ -150,7 +151,7 @@ def load_model_and_tokenizer( if getattr(config, "quantization_config", None): if model_args.quantization_bit is not None: # remove bnb quantization model_args.quantization_bit = None - config_kwargs["device_map"] = {"": get_current_device()} + config_kwargs["device_map"] = {"": int(os.environ.get("LOCAL_RANK", "0"))} quantization_config = getattr(config, "quantization_config", None) logger.info("Loading {}-bit quantized model.".format(quantization_config.get("bits", -1))) @@ -172,7 +173,7 @@ def load_model_and_tokenizer( bnb_4bit_quant_type=model_args.quantization_type ) - config_kwargs["device_map"] = {"": get_current_device()} + config_kwargs["device_map"] = {"": int(os.environ.get("LOCAL_RANK", "0"))} logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit)) # Load pre-trained models (without valuehead)