Former-commit-id: 11be28201f688ac21cf94135067d37e9aa7ab0a1
This commit is contained in:
hiyouga 2023-12-02 00:37:53 +08:00
parent 8ca196d51f
commit 5ea6a7c6d6

View File

@ -25,7 +25,7 @@ def dispatch_model(model: "PreTrainedModel") -> "PreTrainedModel":
if getattr(model, "quantization_method", None): # already set on current device
return model
if torch.cuda.device_count() > 1:
if torch.cuda.device_count() > 1 and getattr(model.config, "model_type", None) != "chatglm":
from accelerate import dispatch_model
from accelerate.utils import infer_auto_device_map, get_balanced_memory