Former-commit-id: a973ce6e890d3f384fd225334f53a49907fff10d
This commit is contained in:
hiyouga 2023-12-01 23:37:10 +08:00
parent e0da912f8e
commit c8eff09c7c

View File

@ -22,11 +22,7 @@ def dispatch_model(model: "PreTrainedModel") -> "PreTrainedModel":
Dispatches a pre-trained model to GPUs with balanced memory.
Borrowed from: https://github.com/huggingface/transformers/blob/v4.31.0/src/transformers/modeling_utils.py#L2803
"""
if (
getattr(model, "is_loaded_in_8bit", False) # bnb
or getattr(model, "is_loaded_in_4bit", False) # bnb
or getattr(model.config, "quantization_config", None) # gptq or awq
): # already set on current device
if getattr(model, "quantization_method", None): # already set on current device
return model
if torch.cuda.device_count() > 1: