add max_memory for gptq #1923

hiyouga
2023-12-20 18:15:17 +08:00
parent 31165a9822
commit c4a3977ad7
4 changed files with 26 additions and 24 deletions

@@ -76,6 +76,7 @@ def configure_quantization(
     if finetuning_args.export_quantization_bit is not None:  # gptq
         require_version("optimum>=1.16.0", "To fix: pip install optimum>=1.16.0")
         require_version("auto_gptq>=0.5.0", "To fix: pip install auto_gptq>=0.5.0")
+        from accelerate.utils import get_max_memory
         if getattr(config, "model_type", None) == "chatglm":
             raise ValueError("ChatGLM model is not supported.")
@@ -86,6 +87,7 @@ def configure_quantization(
             dataset=get_quantization_dataset(tokenizer, model_args, finetuning_args)
         )
         config_kwargs["device_map"] = "auto"
+        config_kwargs["max_memory"] = get_max_memory()
         logger.info("Quantizing model to {} bit.".format(finetuning_args.export_quantization_bit))