diff --git a/src/llamafactory/model/model_utils/quantization.py b/src/llamafactory/model/model_utils/quantization.py
index 8728ce3ba..417ab1112 100644
--- a/src/llamafactory/model/model_utils/quantization.py
+++ b/src/llamafactory/model/model_utils/quantization.py
@@ -22,7 +22,7 @@ from typing import TYPE_CHECKING, Any
 
 import torch
 from datasets import load_dataset
-from transformers import BitsAndBytesConfig, EetqConfig, FineGrainedFP8Config, GPTQConfig, HqqConfig
+from transformers import BitsAndBytesConfig, EetqConfig, GPTQConfig, HqqConfig
 from transformers.integrations import is_deepspeed_zero3_enabled
 from transformers.modeling_utils import is_fsdp_enabled
 
@@ -94,10 +94,27 @@ def configure_quantization(
         quantization_config: dict[str, Any] = getattr(config, "quantization_config", None)
         quant_method = quantization_config.get("quant_method", "")
 
-        if quant_method != QuantizationMethod.MXFP4 and (is_deepspeed_zero3_enabled() or is_fsdp_enabled()):
+        if (
+            quant_method not in (QuantizationMethod.MXFP4, QuantizationMethod.FP8)
+            and (is_deepspeed_zero3_enabled() or is_fsdp_enabled())
+        ):
             # mxfp4 will dequant the model weights
             raise ValueError("DeepSpeed ZeRO-3 or FSDP is incompatible with PTQ-quantized models.")
 
+        if quant_method == QuantizationMethod.MXFP4:
+            from transformers import Mxfp4Config
+
+            quant_config = Mxfp4Config(dequantize=True)
+            init_kwargs["quantization_config"] = quant_config
+            init_kwargs["ignore_mismatched_sizes"] = True
+
+        if quant_method == QuantizationMethod.FP8:
+            from transformers import FineGrainedFP8Config
+
+            quant_config = FineGrainedFP8Config(dequantize=True)
+            init_kwargs["quantization_config"] = quant_config
+            init_kwargs["ignore_mismatched_sizes"] = True
+
         if quant_method == QuantizationMethod.GPTQ:
             check_version("gptqmodel>=2.0.0", mandatory=True)
             quantization_config.pop("disable_exllama", None)  # remove deprecated args
@@ -110,11 +127,6 @@
             check_version("aqlm>=1.1.0", mandatory=True)
             quantization_config["bits"] = 2
 
-        if quant_method == QuantizationMethod.FP8:
-            quant_config = FineGrainedFP8Config(dequantize=True)
-            init_kwargs["quantization_config"] = quant_config
-            init_kwargs["ignore_mismatched_sizes"] = True
-
         quant_bits = quantization_config.get("bits", "?")
         logger.info_rank0(f"Loading {quant_bits}-bit {quant_method.upper()}-quantized model.")
 
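Note on the ZeRO-3/FSDP guard added above: the exemption must be written as a tuple literal, `(QuantizationMethod.MXFP4, QuantizationMethod.FP8)`. Writing `QuantizationMethod.MXFP4 and QuantizationMethod.FP8` instead is a boolean expression that collapses to a single value, so MXFP4 checkpoints would still be rejected under DeepSpeed ZeRO-3 or FSDP. A minimal, self-contained sketch of the difference (plain strings stand in for the `QuantizationMethod` members; this is not LLaMA-Factory code):

```python
# Sketch only: plain strings stand in for the str-valued QuantizationMethod enum.
MXFP4 = "mxfp4"
FP8 = "fp8"

quant_method = "mxfp4"  # value read from config.quantization_config["quant_method"]

# Boolean form: (MXFP4 and FP8) evaluates to the single string "fp8",
# so the check degenerates into a substring test and still flags mxfp4.
print(quant_method not in (MXFP4 and FP8))  # True -> would raise the ValueError

# Tuple form used in the patch: both dequantizable formats are exempted.
print(quant_method not in (MXFP4, FP8))     # False -> no error, model is dequantized
```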