diff --git a/src/llmtuner/model/utils.py b/src/llmtuner/model/utils.py
index 1eab538d..12a45445 100644
--- a/src/llmtuner/model/utils.py
+++ b/src/llmtuner/model/utils.py
@@ -22,11 +22,7 @@ def dispatch_model(model: "PreTrainedModel") -> "PreTrainedModel":
     Dispatches a pre-trained model to GPUs with balanced memory.
     Borrowed from: https://github.com/huggingface/transformers/blob/v4.31.0/src/transformers/modeling_utils.py#L2803
     """
-    if (
-        getattr(model, "is_loaded_in_8bit", False)  # bnb
-        or getattr(model, "is_loaded_in_4bit", False)  # bnb
-        or getattr(model.config, "quantization_config", None)  # gptq or awq
-    ):  # already set on current device
+    if getattr(model, "quantization_method", None):  # already set on current device
         return model

     if torch.cuda.device_count() > 1:
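
Note (not part of the patch): the removed three-way condition checked the bitsandbytes flags and the raw `quantization_config`, while the new single condition assumes a `quantization_method` attribute is set on every quantized model (bnb, GPTQ, or AWQ) before `dispatch_model` runs. Below is a minimal sketch of that assumption only; `_FakeQuantizedModel` and `is_already_dispatched` are hypothetical names used for illustration, not part of llmtuner or transformers.

    from typing import Optional


    class _FakeQuantizedModel:
        """Hypothetical stand-in for a quantized PreTrainedModel."""

        def __init__(self, quantization_method: Optional[str] = None) -> None:
            if quantization_method is not None:
                # e.g. "bitsandbytes", "gptq", "awq"
                self.quantization_method = quantization_method


    def is_already_dispatched(model) -> bool:
        # Mirrors the new condition in dispatch_model(): a model carrying a
        # truthy quantization_method is treated as already placed on its device.
        return bool(getattr(model, "quantization_method", None))


    assert is_already_dispatched(_FakeQuantizedModel("bitsandbytes"))
    assert not is_already_dispatched(_FakeQuantizedModel())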