fix rm server

This commit is contained in:
hiyouga
2024-01-03 15:30:46 +08:00
parent 3014e3c189
commit 55021097d5
3 changed files with 4 additions and 2 deletions

View File

@@ -27,7 +27,7 @@ def dispatch_model(model: "PreTrainedModel") -> "PreTrainedModel":
from accelerate import dispatch_model
from accelerate.utils import infer_auto_device_map, get_balanced_memory
if model._no_split_modules is None:
if getattr(model, "_no_split_modules", None) is None:
raise ValueError("The model class needs to implement the `_no_split_modules` attribute.")
kwargs = {"dtype": model.dtype, "no_split_module_classes": model._get_no_split_modules("auto")}