fix unusual output of 8bit models #278 #391

This commit is contained in:
hiyouga
2023-08-12 00:25:29 +08:00
parent a48cb0d474
commit dd51c24203
2 changed files with 4 additions and 1 deletions

View File

@@ -142,6 +142,9 @@ def dispatch_model(model: "PreTrainedModel") -> "PreTrainedModel":
Dispatches a pre-trained model to GPUs with balanced memory.
Borrowed from: https://github.com/huggingface/transformers/blob/v4.31.0/src/transformers/modeling_utils.py#L2803
"""
if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False): # do nothing
return model
if torch.cuda.device_count() > 1:
from accelerate import dispatch_model
from accelerate.utils import infer_auto_device_map, get_balanced_memory