diff --git a/src/llmtuner/model/utils.py b/src/llmtuner/model/utils.py index a9138e7e..42bef35b 100644 --- a/src/llmtuner/model/utils.py +++ b/src/llmtuner/model/utils.py @@ -25,7 +25,7 @@ def dispatch_model(model: "PreTrainedModel") -> "PreTrainedModel": if getattr(model, "quantization_method", None): # already set on current device return model - if torch.cuda.device_count() > 1: + if torch.cuda.device_count() > 1 and getattr(model.config, "model_type", None) != "chatglm": from accelerate import dispatch_model from accelerate.utils import infer_auto_device_map, get_balanced_memory