diff --git a/src/llmtuner/extras/misc.py b/src/llmtuner/extras/misc.py
index e1ae7d9f..10f5799e 100644
--- a/src/llmtuner/extras/misc.py
+++ b/src/llmtuner/extras/misc.py
@@ -72,7 +72,9 @@ def get_current_device() -> str:
     import accelerate
     if accelerate.utils.is_xpu_available():
         return "xpu:{}".format(os.environ.get("LOCAL_RANK", "0"))
-    elif accelerate.utils.is_npu_available() or torch.cuda.is_available():
+    elif accelerate.utils.is_npu_available():
+        return "npu:{}".format(os.environ.get("LOCAL_RANK", "0"))
+    elif torch.cuda.is_available():
         return "cuda:{}".format(os.environ.get("LOCAL_RANK", "0"))
     else:
         return "cpu"
diff --git a/src/llmtuner/model/adapter.py b/src/llmtuner/model/adapter.py
index 9a3fb3f6..53dfd6ea 100644
--- a/src/llmtuner/model/adapter.py
+++ b/src/llmtuner/model/adapter.py
@@ -87,7 +87,7 @@ def init_adapter(
 
         if is_trainable and checkpoint_to_resume is None:  # create new lora weights while training
             if len(finetuning_args.lora_target) == 1 and finetuning_args.lora_target[0] == "all":
-                target_modules = find_all_linear_modules(model, model_args.quantization_bit)
+                target_modules = find_all_linear_modules(model)
             else:
                 target_modules = finetuning_args.lora_target
 
diff --git a/src/llmtuner/model/utils.py b/src/llmtuner/model/utils.py
index 12a45445..a9138e7e 100644
--- a/src/llmtuner/model/utils.py
+++ b/src/llmtuner/model/utils.py
@@ -42,18 +42,18 @@ def dispatch_model(model: "PreTrainedModel") -> "PreTrainedModel":
         return model.cuda()
 
 
-def find_all_linear_modules(
-    model: "PreTrainedModel",
-    quantization_bit: Optional[int] = None
-) -> List[str]:
+def find_all_linear_modules(model: "PreTrainedModel") -> List[str]:
     r"""
    Finds all available modules to apply lora.
     """
-    if quantization_bit is not None:
-        import bitsandbytes as bnb
-        linear_cls = bnb.nn.Linear4bit if quantization_bit == 4 else bnb.nn.Linear8bitLt
-    else:
+    quantization_method = getattr(model, "quantization_method", None)
+    if quantization_method is None:
         linear_cls = torch.nn.Linear
+    elif quantization_method == "bitsandbytes":
+        import bitsandbytes as bnb
+        linear_cls = bnb.nn.Linear4bit if getattr(model, "is_loaded_in_4bit", False) else bnb.nn.Linear8bitLt
+    else:
+        raise ValueError("Finding linear modules for {} models is not supported.".format(quantization_method))
 
     output_layer_names = ["lm_head"]
     if model.config.model_type == "chatglm":
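
Note (not part of the patch): a minimal standalone sketch of the revised detection idea, where the quantization backend is read from model attributes instead of a quantization_bit argument. Only the attribute-based branch mirrors the patched function above; the helper name and the module-scanning loop (which falls outside the shown hunk) are assumptions for illustration.

# Illustrative sketch; not the repository's exact code.
from typing import List

import torch


def find_linear_module_names(model) -> List[str]:
    # Choose which Linear class to search for based on how the model was quantized,
    # using attributes that transformers sets on quantized models.
    quantization_method = getattr(model, "quantization_method", None)
    if quantization_method is None:
        linear_cls = torch.nn.Linear
    elif quantization_method == "bitsandbytes":
        import bitsandbytes as bnb
        linear_cls = bnb.nn.Linear4bit if getattr(model, "is_loaded_in_4bit", False) else bnb.nn.Linear8bitLt
    else:
        raise ValueError("Finding linear modules for {} models is not supported.".format(quantization_method))

    # Collect the trailing component of each matching module name, skipping output layers
    # such as lm_head; the result can be passed as LoRA target modules.
    output_layer_names = ["lm_head"]
    module_names = set()
    for name, module in model.named_modules():
        if isinstance(module, linear_cls) and not any(output in name for output in output_layer_names):
            module_names.add(name.split(".")[-1])
    return list(module_names)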