Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-08-03 20:22:49 +08:00)
fix gptq model inference
Former-commit-id: 01e6c539b08e45c18878deb99550099d62f7645b
parent 5d7c97b4de
commit e0da912f8e
@@ -146,22 +146,25 @@ def load_model_and_tokenizer(
        else:
            logger.warning("Current model does not support shift short attention.")

    # Quantization configurations (using gptq or awq)
    if getattr(config, "quantization_config", None):
        if model_args.quantization_bit is not None:  # remove bnb quantization
            model_args.quantization_bit = None
        config_kwargs["device_map"] = {"": get_current_device()}
        quantization_config = getattr(config, "quantization_config", None)
        logger.info("Loading {}-bit quantized model.".format(quantization_config.get("bits", -1)))

    # Quantization configurations (using bitsandbytes library)
    if model_args.quantization_bit is not None:
        if getattr(config, "quantization_config", None):
            raise ValueError("Remove `quantization_bit` if you are using a quantized model.")

        if is_deepspeed_zero3_enabled():
            raise ValueError("DeepSpeed ZeRO-3 is incompatible with quantization.")

        if model_args.quantization_bit == 8:
            require_version("bitsandbytes>=0.37.0", "To fix: pip install bitsandbytes>=0.37.0")
            config_kwargs["load_in_8bit"] = True
            config_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)

        if model_args.quantization_bit == 4:
            require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0")
            config_kwargs["load_in_4bit"] = True
            config_kwargs["quantization_config"] = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=model_args.compute_dtype,
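For readers following the change: a minimal sketch of how the two branches above behave downstream, assuming the loader ultimately passes config_kwargs to transformers' AutoModelForCausalLM.from_pretrained. The checkpoint id, the 4-bit settings, and the fixed device index are placeholders, not the project's actual values.

    # Sketch only: illustrates the two quantization paths with stock transformers
    # APIs; requires accelerate for device_map and gptq/awq support for
    # pre-quantized checkpoints.
    import torch
    from transformers import AutoConfig, AutoModelForCausalLM, BitsAndBytesConfig

    model_name = "some-org/some-gptq-model"  # placeholder checkpoint id
    config = AutoConfig.from_pretrained(model_name)
    config_kwargs = {}

    if getattr(config, "quantization_config", None):
        # Checkpoint ships already quantized (gptq or awq): skip bitsandbytes
        # and pin the model to one device, mirroring the gptq/awq branch above.
        config_kwargs["device_map"] = {"": 0}
    else:
        # Full-precision checkpoint: quantize at load time with bitsandbytes,
        # mirroring the 4-bit branch above.
        config_kwargs["quantization_config"] = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.bfloat16,
        )

    model = AutoModelForCausalLM.from_pretrained(model_name, config=config, **config_kwargs)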
@@ -22,7 +22,11 @@ def dispatch_model(model: "PreTrainedModel") -> "PreTrainedModel":
     """
     Dispatches a pre-trained model to GPUs with balanced memory.
     Borrowed from: https://github.com/huggingface/transformers/blob/v4.31.0/src/transformers/modeling_utils.py#L2803
     """
-    if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False):  # do nothing
+    if (
+        getattr(model, "is_loaded_in_8bit", False)  # bnb
+        or getattr(model, "is_loaded_in_4bit", False)  # bnb
+        or getattr(model.config, "quantization_config", None)  # gptq or awq
+    ):  # already set on current device
         return model

     if torch.cuda.device_count() > 1:
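The widened check matters because a gptq/awq checkpoint loaded with a device_map has already been placed by accelerate, so re-dispatching it with balanced memory would try to move quantized weights. As a sketch, the new condition can be read as a standalone predicate (the function name below is ours, not the project's):

    def already_dispatched(model) -> bool:
        # Mirrors the new early-return condition: any quantized model is left
        # where the loader put it, whether quantized at load time by bitsandbytes
        # or shipped pre-quantized (gptq/awq) with a quantization_config on its config.
        return bool(
            getattr(model, "is_loaded_in_8bit", False)
            or getattr(model, "is_loaded_in_4bit", False)
            or getattr(model.config, "quantization_config", None)
        )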