[model] switch to gptqmodel (#8108)

This commit is contained in:
hoshi-hiyouga
2025-05-19 22:25:40 +08:00
committed by GitHub
parent bc7f00f2c7
commit 45030ff803
9 changed files with 78 additions and 62 deletions

View File

@@ -97,7 +97,7 @@ def configure_quantization(
quant_method = quantization_config.get("quant_method", "")
if quant_method == QuantizationMethod.GPTQ:
check_version("auto_gptq>=0.5.0", mandatory=True)
check_version("gptqmodel>=2.0.0", mandatory=True)
quantization_config.pop("disable_exllama", None) # remove deprecated args
quantization_config["use_exllama"] = False # disable exllama
@@ -111,12 +111,12 @@ def configure_quantization(
quant_bits = quantization_config.get("bits", "?")
logger.info_rank0(f"Loading {quant_bits}-bit {quant_method.upper()}-quantized model.")
elif model_args.export_quantization_bit is not None: # auto-gptq
elif model_args.export_quantization_bit is not None: # gptqmodel
if model_args.export_quantization_bit not in [8, 4, 3, 2]:
raise ValueError("AutoGPTQ only accepts 2/3/4/8-bit quantization.")
check_version("optimum>=1.17.0", mandatory=True)
check_version("auto_gptq>=0.5.0", mandatory=True)
check_version("optimum>=1.24.0", mandatory=True)
check_version("gptqmodel>=2.0.0", mandatory=True)
from accelerate.utils import get_max_memory
if getattr(config, "model_type", None) == "chatglm":
@@ -142,7 +142,8 @@ def configure_quantization(
)
init_kwargs["device_map"] = "auto"
init_kwargs["max_memory"] = get_max_memory()
logger.info_rank0(f"Quantizing model to {model_args.export_quantization_bit} bit with AutoGPTQ.")
model_args.compute_dtype = torch.float16 # force fp16 for gptqmodel
logger.info_rank0(f"Quantizing model to {model_args.export_quantization_bit} bit with GPTQModel.")
elif model_args.quantization_bit is not None: # on-the-fly
if model_args.quantization_method == QuantizationMethod.BNB: