mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-03 04:02:49 +08:00
tiny fix
Former-commit-id: a973ce6e890d3f384fd225334f53a49907fff10d
This commit is contained in:
parent
e0da912f8e
commit
c8eff09c7c
@ -22,11 +22,7 @@ def dispatch_model(model: "PreTrainedModel") -> "PreTrainedModel":
|
||||
Dispatches a pre-trained model to GPUs with balanced memory.
|
||||
Borrowed from: https://github.com/huggingface/transformers/blob/v4.31.0/src/transformers/modeling_utils.py#L2803
|
||||
"""
|
||||
if (
|
||||
getattr(model, "is_loaded_in_8bit", False) # bnb
|
||||
or getattr(model, "is_loaded_in_4bit", False) # bnb
|
||||
or getattr(model.config, "quantization_config", None) # gptq or awq
|
||||
): # already set on current device
|
||||
if getattr(model, "quantization_method", None): # already set on current device
|
||||
return model
|
||||
|
||||
if torch.cuda.device_count() > 1:
|
||||
|
Loading…
x
Reference in New Issue
Block a user