From f3fd67a9bb0322e0bc09bda31a9a0a8986d7438c Mon Sep 17 00:00:00 2001
From: hoshi-hiyouga
Date: Mon, 19 May 2025 22:25:40 +0800
Subject: [PATCH] [model] switch to gptqmodel (#8108)

---
 requirements.txt                              |  8 ++--
 setup.py                                      |  2 +-
 src/llamafactory/extras/misc.py               | 17 ++++---
 src/llamafactory/hparams/parser.py            |  2 +-
 .../model/model_utils/attention.py            |  6 +--
 src/llamafactory/model/model_utils/moe.py     | 42 ++++++++--------
 .../model/model_utils/quantization.py         | 11 +++--
 src/llamafactory/model/model_utils/rope.py    | 48 +++++++++++--------
 src/llamafactory/model/patcher.py             |  4 +-
 9 files changed, 78 insertions(+), 62 deletions(-)

diff --git a/requirements.txt b/requirements.txt
index c56f9ac5..484ebae0 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,10 +1,10 @@
 transformers>=4.45.0,<=4.51.3,!=4.46.*,!=4.47.*,!=4.48.0
-datasets>=2.16.0,<=3.5.0
-accelerate>=0.34.0,<=1.6.0
-peft>=0.14.0,<=0.15.1
+datasets>=2.16.0,<=3.6.0
+accelerate>=0.34.0,<=1.7.0
+peft>=0.14.0,<=0.15.2
 trl>=0.8.6,<=0.9.6
 tokenizers>=0.19.0,<=0.21.1
-gradio>=4.38.0,<=5.25.0
+gradio>=4.38.0,<=5.29.1
 scipy
 einops
 sentencepiece
diff --git a/setup.py b/setup.py
index 2561d4a4..4c29ca87 100644
--- a/setup.py
+++ b/setup.py
@@ -50,7 +50,7 @@ extra_require = {
     "bitsandbytes": ["bitsandbytes>=0.39.0"],
     "hqq": ["hqq"],
     "eetq": ["eetq"],
-    "gptq": ["optimum>=1.17.0", "auto-gptq>=0.5.0"],
+    "gptq": ["optimum>=1.24.0", "gptqmodel>=2.0.0"],
     "aqlm": ["aqlm[gpu]>=1.1.0"],
     "vllm": ["vllm>=0.4.3,<=0.8.5"],
     "sglang": ["sglang[srt]>=0.4.5", "transformers==4.51.1"],
diff --git a/src/llamafactory/extras/misc.py b/src/llamafactory/extras/misc.py
index e5b91256..dcc22c1b 100644
--- a/src/llamafactory/extras/misc.py
+++ b/src/llamafactory/extras/misc.py
@@ -79,10 +79,15 @@ def check_version(requirement: str, mandatory: bool = False) -> None:
         logger.warning_rank0_once("Version checking has been disabled, may lead to unexpected behaviors.")
         return
 
-    if mandatory:
-        hint = f"To fix: run `pip install {requirement}`."
+    if "gptqmodel" in requirement or "autoawq" in requirement:
+        pip_command = f"pip install {requirement} --no-build-isolation"
     else:
-        hint = f"To fix: run `pip install {requirement}` or set `DISABLE_VERSION_CHECK=1` to skip this check."
+        pip_command = f"pip install {requirement}"
+
+    if mandatory:
+        hint = f"To fix: run `{pip_command}`."
+    else:
+        hint = f"To fix: run `{pip_command}` or set `DISABLE_VERSION_CHECK=1` to skip this check."
 
     require_version(requirement, hint)
 
@@ -90,9 +95,9 @@ def check_version(requirement: str, mandatory: bool = False) -> None:
 def check_dependencies() -> None:
     r"""Check the version of the required packages."""
     check_version("transformers>=4.45.0,<=4.51.3,!=4.46.0,!=4.46.1,!=4.46.2,!=4.46.3,!=4.47.0,!=4.47.1,!=4.48.0")
-    check_version("datasets>=2.16.0,<=3.5.0")
-    check_version("accelerate>=0.34.0,<=1.6.0")
-    check_version("peft>=0.14.0,<=0.15.1")
+    check_version("datasets>=2.16.0,<=3.6.0")
+    check_version("accelerate>=0.34.0,<=1.7.0")
+    check_version("peft>=0.14.0,<=0.15.2")
     check_version("trl>=0.8.6,<=0.9.6")
     if is_transformers_version_greater_than("4.46.0") and not is_transformers_version_greater_than("4.48.1"):
         logger.warning_rank0_once("There are known bugs in transformers v4.46.0-v4.48.0, please use other versions.")
diff --git a/src/llamafactory/hparams/parser.py b/src/llamafactory/hparams/parser.py
index 539b9152..7b0c0476 100644
--- a/src/llamafactory/hparams/parser.py
+++ b/src/llamafactory/hparams/parser.py
@@ -148,7 +148,7 @@ def _check_extra_dependencies(
         check_version("mixture-of-depth>=1.1.6", mandatory=True)
 
     if model_args.infer_backend == EngineName.VLLM:
-        check_version("vllm>=0.4.3,<=0.8.5")
+        check_version("vllm>=0.4.3,<=0.8.6")
         check_version("vllm", mandatory=True)
     elif model_args.infer_backend == EngineName.SGLANG:
         check_version("sglang>=0.4.5")
diff --git a/src/llamafactory/model/model_utils/attention.py b/src/llamafactory/model/model_utils/attention.py
index 860bf891..fb86a163 100644
--- a/src/llamafactory/model/model_utils/attention.py
+++ b/src/llamafactory/model/model_utils/attention.py
@@ -29,10 +29,8 @@ if TYPE_CHECKING:
 logger = logging.get_logger(__name__)
 
 
-def configure_attn_implementation(
-    config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool
-) -> None:
-    if getattr(config, "model_type", None) == "gemma2" and is_trainable:
+def configure_attn_implementation(config: "PretrainedConfig", model_args: "ModelArguments") -> None:
+    if getattr(config, "model_type", None) == "gemma2":
         if model_args.flash_attn == AttentionFunction.AUTO or model_args.flash_attn == AttentionFunction.FA2:
             if is_flash_attn_2_available():
                 if model_args.flash_attn != AttentionFunction.FA2:
diff --git a/src/llamafactory/model/model_utils/moe.py b/src/llamafactory/model/model_utils/moe.py
index 51f289f4..bc517cd6 100644
--- a/src/llamafactory/model/model_utils/moe.py
+++ b/src/llamafactory/model/model_utils/moe.py
@@ -99,27 +99,29 @@ def add_z3_leaf_module(model: "PreTrainedModel") -> None:
 
 
 def configure_moe(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
+    if not is_trainable or not model_args.moe_aux_loss_coef:
+        return
+
     model_type = getattr(config, "model_type", None)
-    if model_args.moe_aux_loss_coef is not None:
-        if model_type in [
-            "dbrx",
-            "granitemoe",
-            "jamba",
-            "jetmoe",
-            "llama4",
-            "mixtral",
-            "olmoe",
-            "phimoe",
-            "qwen2_moe",
-            "qwen3_moe",
-        ]:
-            setattr(config, "output_router_logits", is_trainable)
+    if model_type in [
+        "dbrx",
+        "granitemoe",
+        "jamba",
+        "jetmoe",
+        "llama4",
+        "mixtral",
+        "olmoe",
+        "phimoe",
+        "qwen2_moe",
+        "qwen3_moe",
+    ]:
+        setattr(config, "output_router_logits", True)
 
-        if model_type in ["granitemoe", "jamba", "llama4", "mixtral", "olmoe", "phimoe", "qwen2_moe", "qwen3_moe"]:
-            setattr(config, "router_aux_loss_coef", model_args.moe_aux_loss_coef)
+    if model_type in ["granitemoe", "jamba", "llama4", "mixtral", "olmoe", "phimoe", "qwen2_moe", "qwen3_moe"]:
+        setattr(config, "router_aux_loss_coef", model_args.moe_aux_loss_coef)
 
-        elif model_type == "deepseek":
-            setattr(config, "aux_loss_alpha", model_args.moe_aux_loss_coef)
+    elif model_type == "deepseek":
+        setattr(config, "aux_loss_alpha", model_args.moe_aux_loss_coef)
 
-        elif model_type == "jetmoe":
-            setattr(config, "aux_loss_coef", model_args.moe_aux_loss_coef)
+    elif model_type == "jetmoe":
+        setattr(config, "aux_loss_coef", model_args.moe_aux_loss_coef)
diff --git a/src/llamafactory/model/model_utils/quantization.py b/src/llamafactory/model/model_utils/quantization.py
index cb288af1..d3160f69 100644
--- a/src/llamafactory/model/model_utils/quantization.py
+++ b/src/llamafactory/model/model_utils/quantization.py
@@ -97,7 +97,7 @@ def configure_quantization(
         quant_method = quantization_config.get("quant_method", "")
 
         if quant_method == QuantizationMethod.GPTQ:
-            check_version("auto_gptq>=0.5.0", mandatory=True)
+            check_version("gptqmodel>=2.0.0", mandatory=True)
             quantization_config.pop("disable_exllama", None)  # remove deprecated args
             quantization_config["use_exllama"] = False  # disable exllama
 
@@ -111,12 +111,12 @@
         quant_bits = quantization_config.get("bits", "?")
         logger.info_rank0(f"Loading {quant_bits}-bit {quant_method.upper()}-quantized model.")
 
-    elif model_args.export_quantization_bit is not None:  # auto-gptq
+    elif model_args.export_quantization_bit is not None:  # gptqmodel
         if model_args.export_quantization_bit not in [8, 4, 3, 2]:
             raise ValueError("AutoGPTQ only accepts 2/3/4/8-bit quantization.")
 
-        check_version("optimum>=1.17.0", mandatory=True)
-        check_version("auto_gptq>=0.5.0", mandatory=True)
+        check_version("optimum>=1.24.0", mandatory=True)
+        check_version("gptqmodel>=2.0.0", mandatory=True)
         from accelerate.utils import get_max_memory
 
         if getattr(config, "model_type", None) == "chatglm":
@@ -142,7 +142,8 @@
         )
         init_kwargs["device_map"] = "auto"
         init_kwargs["max_memory"] = get_max_memory()
-        logger.info_rank0(f"Quantizing model to {model_args.export_quantization_bit} bit with AutoGPTQ.")
+        model_args.compute_dtype = torch.float16  # force fp16 for gptqmodel
+        logger.info_rank0(f"Quantizing model to {model_args.export_quantization_bit} bit with GPTQModel.")
 
     elif model_args.quantization_bit is not None:  # on-the-fly
         if model_args.quantization_method == QuantizationMethod.BNB:
diff --git a/src/llamafactory/model/model_utils/rope.py b/src/llamafactory/model/model_utils/rope.py
index 29b56a0b..d04279e0 100644
--- a/src/llamafactory/model/model_utils/rope.py
+++ b/src/llamafactory/model/model_utils/rope.py
@@ -32,7 +32,7 @@ if TYPE_CHECKING:
 logger = logging.get_logger(__name__)
 
 
-def configure_rope(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
+def configure_rope(config: "PretrainedConfig", model_args: "ModelArguments") -> None:
     if model_args.rope_scaling is None:
         return
 
@@ -40,30 +40,40 @@ def configure_rope(config: "PretrainedConfig", model_args: "ModelArguments", is_
         logger.warning_rank0("Current model does not support RoPE scaling.")
         return
 
-    rope_kwargs = {"rope_type": getattr(model_args.rope_scaling, "value", model_args.rope_scaling)}  # handle enum
-    if model_args.model_max_length is not None:
-        if is_trainable and model_args.rope_scaling == RopeScaling.DYNAMIC:
+    if hasattr(config, "max_position_embeddings"):
+        old_max_length = getattr(config, "max_position_embeddings", None)
+    else:
+        logger.warning_rank0("Cannot find the max position embeddings in the config.")
+        return
+
+    if model_args.model_max_length is not None:  # training
+        if model_args.model_max_length <= old_max_length:
+            logger.warning_rank0("Input length is smaller than max length. Disabling rope scaling.")
+            return
+
+        if model_args.rope_scaling == RopeScaling.DYNAMIC:
             logger.warning_rank0(
                 "Dynamic NTK scaling may not work well with fine-tuning. "
                 "See: https://github.com/huggingface/transformers/pull/24653"
             )
 
-        current_max_length = getattr(config, "max_position_embeddings", None)
-        if (not current_max_length) or model_args.model_max_length <= current_max_length:
-            logger.warning_rank0("Input length is smaller than max length. Disabling rope scaling.")
-            return
+        rope_factor = float(math.ceil(model_args.model_max_length / old_max_length))
+    else:  # inference
+        rope_factor = 2.0
 
-        logger.info_rank0(f"Enlarge max model length from {current_max_length} to {model_args.model_max_length}.")
-        setattr(config, "max_position_embeddings", model_args.model_max_length)
-        rope_kwargs["factor"] = float(math.ceil(model_args.model_max_length / current_max_length))
-
-        if model_args.rope_scaling in [RopeScaling.DYNAMIC, RopeScaling.YARN]:
-            rope_kwargs["original_max_position_embeddings"] = current_max_length
-        elif model_args.rope_scaling == RopeScaling.LLAMA3:
-            rope_kwargs["original_max_position_embeddings"] = current_max_length
-            rope_kwargs["low_freq_factor"] = 1.0
-            rope_kwargs["high_freq_factor"] = 4.0
-    else:
-        rope_kwargs["factor"] = 2.0
+    rope_kwargs = {
+        "rope_type": getattr(model_args.rope_scaling, "value", model_args.rope_scaling),  # handle enum
+        "factor": rope_factor,
+    }
+    setattr(config, "max_position_embeddings", old_max_length * rope_factor)
+    logger.info_rank0(f"Enlarge max model length from {old_max_length} to {old_max_length * rope_factor}.")
+
+    if model_args.rope_scaling in [RopeScaling.DYNAMIC, RopeScaling.YARN]:
+        rope_kwargs["original_max_position_embeddings"] = old_max_length
+    elif model_args.rope_scaling == RopeScaling.LLAMA3:
+        rope_kwargs["original_max_position_embeddings"] = old_max_length
+        rope_kwargs["low_freq_factor"] = 1.0
+        rope_kwargs["high_freq_factor"] = 4.0
 
     setattr(config, "rope_scaling", rope_kwargs)
     logger.info_rank0(
diff --git a/src/llamafactory/model/patcher.py b/src/llamafactory/model/patcher.py
index ce1a5a7d..cedcf9da 100644
--- a/src/llamafactory/model/patcher.py
+++ b/src/llamafactory/model/patcher.py
@@ -102,8 +102,8 @@ def patch_config(
     else:
        model_args.compute_dtype = infer_optim_dtype(model_dtype=getattr(config, "torch_dtype", None))
 
-    configure_attn_implementation(config, model_args, is_trainable)
-    configure_rope(config, model_args, is_trainable)
+    configure_attn_implementation(config, model_args)
+    configure_rope(config, model_args)
     configure_longlora(config, model_args, is_trainable)
     configure_quantization(config, tokenizer, model_args, init_kwargs)
     configure_moe(config, model_args, is_trainable)
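
Editorial appendix (not part of the diff): the sketches below illustrate the behavior this patch changes, using assumed example values; helper names such as build_hint are hypothetical. The first mirrors the new hint logic in check_version in src/llamafactory/extras/misc.py: gptqmodel and autoawq compile extensions against the torch already present in the environment, so their requirement strings get a pip command with --no-build-isolation.

def build_hint(requirement: str, mandatory: bool = False) -> str:
    # Extension packages are suggested without build isolation, mirroring the patched branch.
    if "gptqmodel" in requirement or "autoawq" in requirement:
        pip_command = f"pip install {requirement} --no-build-isolation"
    else:
        pip_command = f"pip install {requirement}"

    if mandatory:
        return f"To fix: run `{pip_command}`."

    return f"To fix: run `{pip_command}` or set `DISABLE_VERSION_CHECK=1` to skip this check."


print(build_hint("gptqmodel>=2.0.0", mandatory=True))
# To fix: run `pip install gptqmodel>=2.0.0 --no-build-isolation`.
print(build_hint("datasets>=2.16.0,<=3.6.0"))
# To fix: run `pip install datasets>=2.16.0,<=3.6.0` or set `DISABLE_VERSION_CHECK=1` to skip this check.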
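
The second sketch shows, assuming a Mixtral-style config and an example coefficient, what the refactored configure_moe in src/llamafactory/model/model_utils/moe.py now does during training: the early return skips all MoE handling at inference time or when no auxiliary loss coefficient is set, and output_router_logits becomes a plain True instead of is_trainable.

from transformers import MixtralConfig

config = MixtralConfig()      # stand-in for the loaded model config
moe_aux_loss_coef = 0.001     # example value of model_args.moe_aux_loss_coef
is_trainable = True

if is_trainable and moe_aux_loss_coef:  # mirrors the new early-return guard
    config.output_router_logits = True              # router logits are needed for the aux loss
    config.router_aux_loss_coef = moe_aux_loss_coef

print(config.output_router_logits, config.router_aux_loss_coef)  # True 0.001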
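
The third is a hedged usage sketch of the quantization route that model_utils/quantization.py feeds into. With optimum>=1.24.0 and gptqmodel>=2.0.0 installed, the GPTQConfig path in transformers can run GPTQ quantization through GPTQModel rather than AutoGPTQ; the model id, bit width, and calibration dataset below are examples only, not values hardcoded by LLaMA-Factory.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, GPTQConfig

model_id = "Qwen/Qwen2.5-0.5B-Instruct"  # example base model
tokenizer = AutoTokenizer.from_pretrained(model_id)

gptq_config = GPTQConfig(bits=4, dataset="c4", tokenizer=tokenizer)  # example calibration setup

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=gptq_config,
    torch_dtype=torch.float16,  # matches the fp16 compute dtype forced in the patch
    device_map="auto",
)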
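
Finally, the arithmetic of the reworked RoPE scaling in model_utils/rope.py can be reproduced in plain Python. The numbers below are examples, and string literals stand in for the RopeScaling enum members.

import math

old_max_length = 4096       # config.max_position_embeddings
model_max_length = 10000    # model_args.model_max_length (training case)
rope_scaling = "dynamic"    # stand-in for a RopeScaling enum value

if model_max_length is not None:  # training: stretch to cover the requested length
    rope_factor = float(math.ceil(model_max_length / old_max_length))  # -> 3.0
else:  # inference: fixed 2x extension
    rope_factor = 2.0

rope_kwargs = {"rope_type": rope_scaling, "factor": rope_factor}
if rope_scaling in ("dynamic", "yarn"):
    rope_kwargs["original_max_position_embeddings"] = old_max_length
elif rope_scaling == "llama3":
    rope_kwargs["original_max_position_embeddings"] = old_max_length
    rope_kwargs["low_freq_factor"] = 1.0
    rope_kwargs["high_freq_factor"] = 4.0

new_max_position_embeddings = old_max_length * rope_factor  # 12288.0
print(rope_kwargs)
# {'rope_type': 'dynamic', 'factor': 3.0, 'original_max_position_embeddings': 4096}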