Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-11-04 18:02:19 +08:00)
[model] switch to gptqmodel (#8108)

parent bc7f00f2c7
commit 45030ff803
Dependency pins:

@@ -1,10 +1,10 @@
 transformers>=4.45.0,<=4.51.3,!=4.46.*,!=4.47.*,!=4.48.0
-datasets>=2.16.0,<=3.5.0
-accelerate>=0.34.0,<=1.6.0
-peft>=0.14.0,<=0.15.1
+datasets>=2.16.0,<=3.6.0
+accelerate>=0.34.0,<=1.7.0
+peft>=0.14.0,<=0.15.2
 trl>=0.8.6,<=0.9.6
 tokenizers>=0.19.0,<=0.21.1
-gradio>=4.38.0,<=5.25.0
+gradio>=4.38.0,<=5.29.1
 scipy
 einops
 sentencepiece
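To check an existing environment against the widened pins, the installed versions can be evaluated with the packaging library. This snippet is an illustration only, not part of the commit:

import importlib.metadata

from packaging.requirements import Requirement

# One of the bumped pins from the hunk above.
req = Requirement("datasets>=2.16.0,<=3.6.0")
installed = importlib.metadata.version(req.name)

# True if the installed version falls inside the allowed range.
print(req.name, installed, req.specifier.contains(installed, prereleases=True))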
							
								
								
									
setup.py (2 changed lines):
@@ -50,7 +50,7 @@ extra_require = {
     "bitsandbytes": ["bitsandbytes>=0.39.0"],
     "hqq": ["hqq"],
     "eetq": ["eetq"],
-    "gptq": ["optimum>=1.17.0", "auto-gptq>=0.5.0"],
+    "gptq": ["optimum>=1.24.0", "gptqmodel>=2.0.0"],
     "aqlm": ["aqlm[gpu]>=1.1.0"],
     "vllm": ["vllm>=0.4.3,<=0.8.5"],
     "sglang": ["sglang[srt]>=0.4.5", "transformers==4.51.1"],
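The "gptq" extra now resolves to optimum>=1.24.0 plus gptqmodel>=2.0.0 instead of auto-gptq; gptqmodel is meant to be installed with --no-build-isolation, which the check_version change further down also encodes. For an installed build, the recorded extras can be inspected from package metadata; the distribution name "llamafactory" is assumed here for illustration:

import importlib.metadata

# Requirement strings from the package metadata, including extras markers.
for req in importlib.metadata.requires("llamafactory") or []:
    if "gptq" in req:
        print(req)  # the optimum and gptqmodel entries guarded by the "gptq" extra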
@@ -79,10 +79,15 @@ def check_version(requirement: str, mandatory: bool = False) -> None:
         logger.warning_rank0_once("Version checking has been disabled, may lead to unexpected behaviors.")
         return
 
-    if mandatory:
-        hint = f"To fix: run `pip install {requirement}`."
+    if "gptmodel" in requirement or "autoawq" in requirement:
+        pip_command = f"pip install {requirement} --no-build-isolation"
     else:
-        hint = f"To fix: run `pip install {requirement}` or set `DISABLE_VERSION_CHECK=1` to skip this check."
+        pip_command = f"pip install {requirement}"
+
+    if mandatory:
+        hint = f"To fix: run `{pip_command}`."
+    else:
+        hint = f"To fix: run `{pip_command}` or set `DISABLE_VERSION_CHECK=1` to skip this check."
 
     require_version(requirement, hint)
 
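For illustration (this helper does not exist in the repo), the new hint construction behaves as follows. Note that the branch tests for the substring "gptmodel", which never matches "gptqmodel" (the latter has a "q" between "gpt" and "model"), so the --no-build-isolation hint is currently only emitted for autoawq; the intended spelling is presumably "gptqmodel":

# Standalone sketch of the new hint logic in check_version.
def build_hint(requirement: str, mandatory: bool) -> str:
    if "gptmodel" in requirement or "autoawq" in requirement:
        pip_command = f"pip install {requirement} --no-build-isolation"
    else:
        pip_command = f"pip install {requirement}"

    if mandatory:
        return f"To fix: run `{pip_command}`."
    return f"To fix: run `{pip_command}` or set `DISABLE_VERSION_CHECK=1` to skip this check."

print(build_hint("autoawq", mandatory=True))
# To fix: run `pip install autoawq --no-build-isolation`.
print(build_hint("gptqmodel>=2.0.0", mandatory=True))
# To fix: run `pip install gptqmodel>=2.0.0`.  (the substring check misses it)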
@@ -90,9 +95,9 @@ def check_version(requirement: str, mandatory: bool = False) -> None:
 def check_dependencies() -> None:
     r"""Check the version of the required packages."""
     check_version("transformers>=4.45.0,<=4.51.3,!=4.46.0,!=4.46.1,!=4.46.2,!=4.46.3,!=4.47.0,!=4.47.1,!=4.48.0")
-    check_version("datasets>=2.16.0,<=3.5.0")
-    check_version("accelerate>=0.34.0,<=1.6.0")
-    check_version("peft>=0.14.0,<=0.15.1")
+    check_version("datasets>=2.16.0,<=3.6.0")
+    check_version("accelerate>=0.34.0,<=1.7.0")
+    check_version("peft>=0.14.0,<=0.15.2")
     check_version("trl>=0.8.6,<=0.9.6")
     if is_transformers_version_greater_than("4.46.0") and not is_transformers_version_greater_than("4.48.1"):
         logger.warning_rank0_once("There are known bugs in transformers v4.46.0-v4.48.0, please use other versions.")
@@ -148,7 +148,7 @@ def _check_extra_dependencies(
         check_version("mixture-of-depth>=1.1.6", mandatory=True)
 
     if model_args.infer_backend == EngineName.VLLM:
-        check_version("vllm>=0.4.3,<=0.8.5")
+        check_version("vllm>=0.4.3,<=0.8.6")
         check_version("vllm", mandatory=True)
     elif model_args.infer_backend == EngineName.SGLANG:
         check_version("sglang>=0.4.5")
@@ -29,10 +29,8 @@ if TYPE_CHECKING:
 logger = logging.get_logger(__name__)
 
 
-def configure_attn_implementation(
-    config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool
-) -> None:
-    if getattr(config, "model_type", None) == "gemma2" and is_trainable:
+def configure_attn_implementation(config: "PretrainedConfig", model_args: "ModelArguments") -> None:
+    if getattr(config, "model_type", None) == "gemma2":
         if model_args.flash_attn == AttentionFunction.AUTO or model_args.flash_attn == AttentionFunction.FA2:
             if is_flash_attn_2_available():
                 if model_args.flash_attn != AttentionFunction.FA2:
@@ -99,27 +99,29 @@ def add_z3_leaf_module(model: "PreTrainedModel") -> None:
 
 
 def configure_moe(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
+    if not is_trainable or not model_args.moe_aux_loss_coef:
+        return
+
     model_type = getattr(config, "model_type", None)
-    if model_args.moe_aux_loss_coef is not None:
-        if model_type in [
-            "dbrx",
-            "granitemoe",
-            "jamba",
-            "jetmoe",
-            "llama4",
-            "mixtral",
-            "olmoe",
-            "phimoe",
-            "qwen2_moe",
-            "qwen3_moe",
-        ]:
-            setattr(config, "output_router_logits", is_trainable)
+    if model_type in [
+        "dbrx",
+        "granitemoe",
+        "jamba",
+        "jetmoe",
+        "llama4",
+        "mixtral",
+        "olmoe",
+        "phimoe",
+        "qwen2_moe",
+        "qwen3_moe",
+    ]:
+        setattr(config, "output_router_logits", True)
 
-        if model_type in ["granitemoe", "jamba", "llama4", "mixtral", "olmoe", "phimoe", "qwen2_moe", "qwen3_moe"]:
-            setattr(config, "router_aux_loss_coef", model_args.moe_aux_loss_coef)
+    if model_type in ["granitemoe", "jamba", "llama4", "mixtral", "olmoe", "phimoe", "qwen2_moe", "qwen3_moe"]:
+        setattr(config, "router_aux_loss_coef", model_args.moe_aux_loss_coef)
 
-        elif model_type == "deepseek":
-            setattr(config, "aux_loss_alpha", model_args.moe_aux_loss_coef)
+    elif model_type == "deepseek":
+        setattr(config, "aux_loss_alpha", model_args.moe_aux_loss_coef)
 
-        elif model_type == "jetmoe":
-            setattr(config, "aux_loss_coef", model_args.moe_aux_loss_coef)
+    elif model_type == "jetmoe":
+        setattr(config, "aux_loss_coef", model_args.moe_aux_loss_coef)
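To see what the reordered logic does during training, here is a standalone sketch (not from the repo) that applies the same settings to a Hugging Face Mixtral config; the coefficient is an arbitrary placeholder:

from types import SimpleNamespace

from transformers import MixtralConfig

config = MixtralConfig()  # stand-in for the loaded model config
model_args = SimpleNamespace(moe_aux_loss_coef=0.001)  # stand-in for ModelArguments
is_trainable = True

# Same flow as the new configure_moe: do nothing unless training with a non-zero coefficient.
if is_trainable and model_args.moe_aux_loss_coef:
    config.output_router_logits = True  # router logits are required to compute the aux loss
    config.router_aux_loss_coef = model_args.moe_aux_loss_coef

print(config.output_router_logits, config.router_aux_loss_coef)

The behavioral change relative to the old code is the early return: when is_trainable is False, the new version leaves the config untouched instead of explicitly writing output_router_logits = False.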
@@ -97,7 +97,7 @@ def configure_quantization(
         quant_method = quantization_config.get("quant_method", "")
 
         if quant_method == QuantizationMethod.GPTQ:
-            check_version("auto_gptq>=0.5.0", mandatory=True)
+            check_version("gptqmodel>=2.0.0", mandatory=True)
             quantization_config.pop("disable_exllama", None)  # remove deprecated args
             quantization_config["use_exllama"] = False  # disable exllama
 
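In effect, loading a checkpoint whose quantization_config reports quant_method == "gptq" now requires gptqmodel>=2.0.0 to be installed instead of auto-gptq. A hedged usage sketch with an example checkpoint id (not taken from the commit):

from transformers import AutoModelForCausalLM

# Any GPTQ-quantized checkpoint now passes through the gptqmodel check above.
model = AutoModelForCausalLM.from_pretrained(
    "TheBloke/Llama-2-7B-Chat-GPTQ",  # example checkpoint id
    device_map="auto",
)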
@@ -111,12 +111,12 @@ def configure_quantization(
         quant_bits = quantization_config.get("bits", "?")
         logger.info_rank0(f"Loading {quant_bits}-bit {quant_method.upper()}-quantized model.")
 
-    elif model_args.export_quantization_bit is not None:  # auto-gptq
+    elif model_args.export_quantization_bit is not None:  # gptqmodel
         if model_args.export_quantization_bit not in [8, 4, 3, 2]:
             raise ValueError("AutoGPTQ only accepts 2/3/4/8-bit quantization.")
 
-        check_version("optimum>=1.17.0", mandatory=True)
-        check_version("auto_gptq>=0.5.0", mandatory=True)
+        check_version("optimum>=1.24.0", mandatory=True)
+        check_version("gptqmodel>=2.0.0", mandatory=True)
         from accelerate.utils import get_max_memory
 
         if getattr(config, "model_type", None) == "chatglm":
@@ -142,7 +142,8 @@ def configure_quantization(
         )
         init_kwargs["device_map"] = "auto"
         init_kwargs["max_memory"] = get_max_memory()
-        logger.info_rank0(f"Quantizing model to {model_args.export_quantization_bit} bit with AutoGPTQ.")
+        model_args.compute_dtype = torch.float16  # force fp16 for gptqmodel
+        logger.info_rank0(f"Quantizing model to {model_args.export_quantization_bit} bit with GPTQModel.")
 
     elif model_args.quantization_bit is not None:  # on-the-fly
         if model_args.quantization_method == QuantizationMethod.BNB:
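On the export path, the compute dtype is now forced to fp16 and quantization is delegated to GPTQModel through optimum. A rough standalone equivalent using Transformers' GPTQConfig is sketched below; the model id, calibration dataset, and bit width are placeholders rather than values from the commit:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, GPTQConfig

model_id = "meta-llama/Llama-2-7b-hf"  # placeholder model
tokenizer = AutoTokenizer.from_pretrained(model_id)

# bits corresponds to export_quantization_bit (2/3/4/8 are accepted).
gptq_config = GPTQConfig(bits=4, dataset="c4", tokenizer=tokenizer)

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=gptq_config,
    device_map="auto",
    torch_dtype=torch.float16,  # mirrors the forced fp16 compute dtype
)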
@@ -32,7 +32,7 @@ if TYPE_CHECKING:
 logger = logging.get_logger(__name__)
 
 
-def configure_rope(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
+def configure_rope(config: "PretrainedConfig", model_args: "ModelArguments") -> None:
     if model_args.rope_scaling is None:
         return
 
@@ -40,30 +40,40 @@ def configure_rope(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
         logger.warning_rank0("Current model does not support RoPE scaling.")
         return
 
-    rope_kwargs = {"rope_type": getattr(model_args.rope_scaling, "value", model_args.rope_scaling)}  # handle enum
-    if model_args.model_max_length is not None:
-        if is_trainable and model_args.rope_scaling == RopeScaling.DYNAMIC:
+    if hasattr(config, "max_position_embeddings"):
+        old_max_length = getattr(config, "max_position_embeddings", None)
+    else:
+        logger.warning_rank0("Cannot find the max position embeddings in the config.")
+        return
+
+    if model_args.model_max_length is not None:  # training
+        if model_args.model_max_length <= old_max_length:
+            logger.warning_rank0("Input length is smaller than max length. Disabling rope scaling.")
+            return
+
+        if model_args.rope_scaling == RopeScaling.DYNAMIC:
             logger.warning_rank0(
                 "Dynamic NTK scaling may not work well with fine-tuning. "
                 "See: https://github.com/huggingface/transformers/pull/24653"
             )
 
-        current_max_length = getattr(config, "max_position_embeddings", None)
-        if (not current_max_length) or model_args.model_max_length <= current_max_length:
-            logger.warning_rank0("Input length is smaller than max length. Disabling rope scaling.")
-            return
-
-        logger.info_rank0(f"Enlarge max model length from {current_max_length} to {model_args.model_max_length}.")
-        setattr(config, "max_position_embeddings", model_args.model_max_length)
-        rope_kwargs["factor"] = float(math.ceil(model_args.model_max_length / current_max_length))
-        if model_args.rope_scaling in [RopeScaling.DYNAMIC, RopeScaling.YARN]:
-            rope_kwargs["original_max_position_embeddings"] = current_max_length
-        elif model_args.rope_scaling == RopeScaling.LLAMA3:
-            rope_kwargs["original_max_position_embeddings"] = current_max_length
-            rope_kwargs["low_freq_factor"] = 1.0
-            rope_kwargs["high_freq_factor"] = 4.0
-    else:
-        rope_kwargs["factor"] = 2.0
+        rope_factor = float(math.ceil(model_args.model_max_length / old_max_length))
+    else:  # inference
+        rope_factor = 2.0
+
+    rope_kwargs = {
+        "rope_type": getattr(model_args.rope_scaling, "value", model_args.rope_scaling),  # handle enum
+        "factor": rope_factor,
+    }
+    setattr(config, "max_position_embeddings", old_max_length * rope_factor)
+    logger.info_rank0(f"Enlarge max model length from {old_max_length} to {old_max_length * rope_factor}.")
+
+    if model_args.rope_scaling in [RopeScaling.DYNAMIC, RopeScaling.YARN]:
+        rope_kwargs["original_max_position_embeddings"] = old_max_length
+    elif model_args.rope_scaling == RopeScaling.LLAMA3:
+        rope_kwargs["original_max_position_embeddings"] = old_max_length
+        rope_kwargs["low_freq_factor"] = 1.0
+        rope_kwargs["high_freq_factor"] = 4.0
 
     setattr(config, "rope_scaling", rope_kwargs)
     logger.info_rank0(
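Worked through with concrete numbers (chosen for illustration, not taken from the commit): a model whose config reports max_position_embeddings = 4096, trained with model_max_length = 16384 and llama3-style scaling, ends up with the values computed below. At inference time (model_max_length unset), the new code instead falls back to a fixed factor of 2.0.

import math

# Placeholder inputs; in LLaMA-Factory they come from the config and model_args.
old_max_length = 4096
model_max_length = 16384
rope_type = "llama3"  # assumed string value of the RopeScaling enum member

rope_factor = float(math.ceil(model_max_length / old_max_length))  # 4.0

rope_kwargs = {"rope_type": rope_type, "factor": rope_factor}
if rope_type in ("dynamic", "yarn"):
    rope_kwargs["original_max_position_embeddings"] = old_max_length
elif rope_type == "llama3":
    rope_kwargs["original_max_position_embeddings"] = old_max_length
    rope_kwargs["low_freq_factor"] = 1.0
    rope_kwargs["high_freq_factor"] = 4.0

print(old_max_length * rope_factor)  # new max_position_embeddings: 16384.0
print(rope_kwargs)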
@@ -102,8 +102,8 @@ def patch_config(
         else:
             model_args.compute_dtype = infer_optim_dtype(model_dtype=getattr(config, "torch_dtype", None))
 
-    configure_attn_implementation(config, model_args, is_trainable)
-    configure_rope(config, model_args, is_trainable)
+    configure_attn_implementation(config, model_args)
+    configure_rope(config, model_args)
     configure_longlora(config, model_args, is_trainable)
     configure_quantization(config, tokenizer, model_args, init_kwargs)
     configure_moe(config, model_args, is_trainable)