mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-02 03:32:50 +08:00
parent
f8376b228a
commit
6720189f3f
@ -25,7 +25,7 @@ def dispatch_model(model: "PreTrainedModel") -> "PreTrainedModel":
|
|||||||
if getattr(model, "quantization_method", None): # already set on current device
|
if getattr(model, "quantization_method", None): # already set on current device
|
||||||
return model
|
return model
|
||||||
|
|
||||||
if torch.cuda.device_count() > 1:
|
if torch.cuda.device_count() > 1 and getattr(model.config, "model_type", None) != "chatglm":
|
||||||
from accelerate import dispatch_model
|
from accelerate import dispatch_model
|
||||||
from accelerate.utils import infer_auto_device_map, get_balanced_memory
|
from accelerate.utils import infer_auto_device_map, get_balanced_memory
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user