[feat] Models trained and inferred with Mxfp4 are dequantized by default (#9652)

Author: Xunpeng Xiao
Co-authored-by: Yaowei Zheng <hiyouga@buaa.edu.cn>
Committed by: GitHub
Date: 2025-12-24 00:26:40 +08:00
Parent: 84485406b7
Commit: 6a2eafbae3

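In effect, an MXFP4 checkpoint is now loaded with dequantization turned on. A minimal sketch of the resulting behavior (the repository id and the exact call site are assumptions for illustration, not part of this commit):

# Sketch: loading an MXFP4-quantized checkpoint with dequantization enabled,
# mirroring the Mxfp4Config(dequantize=True) path added in the diff below.
from transformers import AutoModelForCausalLM, Mxfp4Config

model = AutoModelForCausalLM.from_pretrained(
    "openai/gpt-oss-20b",  # placeholder: any checkpoint whose quant_method is "mxfp4"
    quantization_config=Mxfp4Config(dequantize=True),  # unpack MXFP4 weights at load time
    ignore_mismatched_sizes=True,  # set alongside dequantization, as in the patch below
)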

@@ -22,7 +22,7 @@ from typing import TYPE_CHECKING, Any
import torch
from datasets import load_dataset
-from transformers import BitsAndBytesConfig, EetqConfig, FineGrainedFP8Config, GPTQConfig, HqqConfig
+from transformers import BitsAndBytesConfig, EetqConfig, GPTQConfig, HqqConfig
from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.modeling_utils import is_fsdp_enabled
@@ -94,10 +94,27 @@ def configure_quantization(
        quantization_config: dict[str, Any] = getattr(config, "quantization_config", None)
        quant_method = quantization_config.get("quant_method", "")
-        if quant_method != QuantizationMethod.MXFP4 and (is_deepspeed_zero3_enabled() or is_fsdp_enabled()):
+        if (
+            quant_method not in (QuantizationMethod.MXFP4, QuantizationMethod.FP8)
+            and (is_deepspeed_zero3_enabled() or is_fsdp_enabled())
+        ):
+            # mxfp4 and fp8 models are dequantized at load time, so they remain compatible
            raise ValueError("DeepSpeed ZeRO-3 or FSDP is incompatible with PTQ-quantized models.")
+        if quant_method == QuantizationMethod.MXFP4:
+            from transformers import Mxfp4Config
+
+            quant_config = Mxfp4Config(dequantize=True)
+            init_kwargs["quantization_config"] = quant_config
+            init_kwargs["ignore_mismatched_sizes"] = True
+
+        if quant_method == QuantizationMethod.FP8:
+            from transformers import FineGrainedFP8Config
+
+            quant_config = FineGrainedFP8Config(dequantize=True)
+            init_kwargs["quantization_config"] = quant_config
+            init_kwargs["ignore_mismatched_sizes"] = True
+
        if quant_method == QuantizationMethod.GPTQ:
            check_version("gptqmodel>=2.0.0", mandatory=True)
            quantization_config.pop("disable_exllama", None)  # remove deprecated args
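Assuming init_kwargs is later forwarded to from_pretrained, the new FP8 branch is equivalent to the following sketch; the repository id is a placeholder:

# Sketch: the loader kwargs contributed by the FP8 branch above.
from transformers import AutoModelForCausalLM, FineGrainedFP8Config

init_kwargs = {
    "quantization_config": FineGrainedFP8Config(dequantize=True),  # dequantize FP8 weights at load time
    "ignore_mismatched_sizes": True,
}
model = AutoModelForCausalLM.from_pretrained("org/fp8-quantized-model", **init_kwargs)  # placeholder repo id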
@@ -110,11 +127,6 @@ def configure_quantization(
check_version("aqlm>=1.1.0", mandatory=True)
quantization_config["bits"] = 2
-        if quant_method == QuantizationMethod.FP8:
-            quant_config = FineGrainedFP8Config(dequantize=True)
-            init_kwargs["quantization_config"] = quant_config
-            init_kwargs["ignore_mismatched_sizes"] = True
        quant_bits = quantization_config.get("bits", "?")
        logger.info_rank0(f"Loading {quant_bits}-bit {quant_method.upper()}-quantized model.")
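The quant_method and bits values read here come from the checkpoint's quantization_config entry. A minimal way to inspect them, assuming the config exposes quantization_config as a dict (as the code above does); the repository id is a placeholder:

# Sketch: inspecting the quantization metadata that configure_quantization reads.
from transformers import AutoConfig

config = AutoConfig.from_pretrained("org/quantized-model")  # placeholder repo id
quantization_config = getattr(config, "quantization_config", None) or {}
print(quantization_config.get("quant_method", ""))  # e.g. "mxfp4", "fp8", "gptq"
print(quantization_config.get("bits", "?"))  # declared bit width, if any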