Mirror of https://github.com/hiyouga/LLaMA-Factory.git
Synced 2025-08-06 05:32:50 +08:00

parent de9148930f
commit f53bc7d9a0
@@ -27,14 +27,12 @@ def dispatch_model(model: "PreTrainedModel") -> "PreTrainedModel":
     if (
         torch.cuda.device_count() > 1
         and isinstance(model, PreTrainedModel)
-        and model.config.model_type != "chatglm"
+        and getattr(model.config, "model_type", None) != "chatglm"
+        and model._no_split_modules is not None
     ):
         from accelerate import dispatch_model
         from accelerate.utils import infer_auto_device_map, get_balanced_memory
 
-        if getattr(model, "_no_split_modules", None) is None:
-            raise ValueError("The model class needs to implement the `_no_split_modules` attribute.")
-
         kwargs = {"dtype": model.dtype, "no_split_module_classes": model._get_no_split_modules("auto")}
         max_memory = get_balanced_memory(model, **kwargs)
         # Make sure tied weights are tied before creating the device map.
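The hunk folds the old in-body `_no_split_modules` check into the dispatch condition, so models whose class does not define the attribute now fall through to the default path instead of raising, and it guards the `model_type` access with `getattr`. For context, below is a minimal standalone sketch of the balanced multi-GPU dispatch flow this function wraps; it is an approximation of the surrounding code, not its exact body, and "gpt2" is a placeholder checkpoint. Note that `_no_split_modules` and `_get_no_split_modules` are private transformers APIs, used here only because the repository code relies on them.

# Minimal sketch, assuming more than one CUDA GPU and a transformers
# model whose class defines _no_split_modules ("gpt2" is a placeholder).
import torch
from transformers import AutoModelForCausalLM
from accelerate import dispatch_model
from accelerate.utils import get_balanced_memory, infer_auto_device_map

model = AutoModelForCausalLM.from_pretrained("gpt2")

if torch.cuda.device_count() > 1 and model._no_split_modules is not None:
    # Module classes that must not be split across devices (e.g. a full decoder block).
    kwargs = {"dtype": model.dtype, "no_split_module_classes": model._get_no_split_modules("auto")}
    # Compute a per-GPU memory budget that spreads the model evenly.
    max_memory = get_balanced_memory(model, **kwargs)
    # Tie weights first so shared tensors are placed on the same device.
    model.tie_weights()
    device_map = infer_auto_device_map(model, max_memory=max_memory, **kwargs)
    model = dispatch_model(model, device_map)
else:
    model = model.cuda()

Balancing memory with get_balanced_memory before inferring the device map spreads layers evenly across GPUs rather than filling GPU 0 first, which is the point of this dispatch path.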