Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-12-27 09:10:35 +08:00)
[feat] Models trained and inferred with Mxfp4 are dequantized by default (#9652)
Co-authored-by: Yaowei Zheng <hiyouga@buaa.edu.cn>
@@ -22,7 +22,7 @@ from typing import TYPE_CHECKING, Any
 import torch
 from datasets import load_dataset
-from transformers import BitsAndBytesConfig, EetqConfig, FineGrainedFP8Config, GPTQConfig, HqqConfig
+from transformers import BitsAndBytesConfig, EetqConfig, GPTQConfig, HqqConfig
 from transformers.integrations import is_deepspeed_zero3_enabled
 from transformers.modeling_utils import is_fsdp_enabled

@@ -94,10 +94,27 @@ def configure_quantization(
         quantization_config: dict[str, Any] = getattr(config, "quantization_config", None)
         quant_method = quantization_config.get("quant_method", "")

-        if quant_method != QuantizationMethod.MXFP4 and (is_deepspeed_zero3_enabled() or is_fsdp_enabled()):
+        if (
+            quant_method not in (QuantizationMethod.MXFP4, QuantizationMethod.FP8)
+            and (is_deepspeed_zero3_enabled() or is_fsdp_enabled())
+        ):
+            # mxfp4 will dequant the model weights
             raise ValueError("DeepSpeed ZeRO-3 or FSDP is incompatible with PTQ-quantized models.")

+        if quant_method == QuantizationMethod.MXFP4:
+            from transformers import Mxfp4Config
+
+            quant_config = Mxfp4Config(dequantize=True)
+            init_kwargs["quantization_config"] = quant_config
+            init_kwargs["ignore_mismatched_sizes"] = True
+
+        if quant_method == QuantizationMethod.FP8:
+            from transformers import FineGrainedFP8Config
+
+            quant_config = FineGrainedFP8Config(dequantize=True)
+            init_kwargs["quantization_config"] = quant_config
+            init_kwargs["ignore_mismatched_sizes"] = True
+
         if quant_method == QuantizationMethod.GPTQ:
             check_version("gptqmodel>=2.0.0", mandatory=True)
             quantization_config.pop("disable_exllama", None)  # remove deprecated args

@@ -110,11 +127,6 @@ def configure_quantization(
             check_version("aqlm>=1.1.0", mandatory=True)
             quantization_config["bits"] = 2

-        if quant_method == QuantizationMethod.FP8:
-            quant_config = FineGrainedFP8Config(dequantize=True)
-            init_kwargs["quantization_config"] = quant_config
-            init_kwargs["ignore_mismatched_sizes"] = True
-
         quant_bits = quantization_config.get("bits", "?")
         logger.info_rank0(f"Loading {quant_bits}-bit {quant_method.upper()}-quantized model.")
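For context on what the new default means at load time, here is a minimal sketch (not part of this commit) of the equivalent transformers call that configure_quantization now sets up through init_kwargs. The checkpoint path is a placeholder, and "higher precision" hedges on the exact dtype the dequantized weights end up in:

# Sketch only: mirrors the init_kwargs injected in the diff above.
# "path/to/mxfp4-checkpoint" is a placeholder, not a real model id.
from transformers import AutoModelForCausalLM, Mxfp4Config

quant_config = Mxfp4Config(dequantize=True)  # unpack MXFP4 weights to higher precision on load
model = AutoModelForCausalLM.from_pretrained(
    "path/to/mxfp4-checkpoint",
    quantization_config=quant_config,
    ignore_mismatched_sizes=True,  # dequantized weights may not match the packed checkpoint shapes
)

FP8 checkpoints follow the same pattern in the diff above, with FineGrainedFP8Config(dequantize=True) taking the place of Mxfp4Config.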