diff --git a/src/llamafactory/model/model_utils/quantization.py b/src/llamafactory/model/model_utils/quantization.py index ffbf5825..8d888ccb 100644 --- a/src/llamafactory/model/model_utils/quantization.py +++ b/src/llamafactory/model/model_utils/quantization.py @@ -122,9 +122,22 @@ def configure_quantization( if getattr(config, "model_type", None) == "chatglm": raise ValueError("ChatGLM model is not supported yet.") + try: + from optimum.gptq import utils as gq_utils + if "language_model.model.layers" not in gq_utils.BLOCK_PATTERNS: + gq_utils.BLOCK_PATTERNS.insert(0, "language_model.model.layers") + except ImportError: + pass + + block_name_to_quantize = None + if getattr(config, "model_type", None) in ["gemma3", "paligemma"]: + block_name_to_quantize = "language_model.model.layers" + init_kwargs["quantization_config"] = GPTQConfig( bits=model_args.export_quantization_bit, + tokenizer=tokenizer, dataset=_get_quantization_dataset(tokenizer, model_args), + block_name_to_quantize=block_name_to_quantize, ) init_kwargs["device_map"] = "auto" init_kwargs["max_memory"] = get_max_memory()