diff --git a/src/llamafactory/model/model_utils/quantization.py b/src/llamafactory/model/model_utils/quantization.py
index 3203b4aa..317646e0 100644
--- a/src/llamafactory/model/model_utils/quantization.py
+++ b/src/llamafactory/model/model_utils/quantization.py
@@ -108,8 +108,11 @@ def configure_quantization(
     Priority: PTQ-quantized (train/infer) > AutoGPTQ (export) > On-the-fly quantization (train/infer)
     """
     if getattr(config, "quantization_config", None):  # ptq
-        if is_deepspeed_zero3_enabled():
-            raise ValueError("DeepSpeed ZeRO-3 is incompatible with PTQ-quantized models.")
+        if model_args.quantization_bit is not None:
+            logger.warning("`quantization_bit` will not affect PTQ-quantized models.")
+
+        if is_deepspeed_zero3_enabled() or is_fsdp_enabled():
+            raise ValueError("DeepSpeed ZeRO-3 or FSDP is incompatible with PTQ-quantized models.")
 
         quantization_config: Dict[str, Any] = getattr(config, "quantization_config", None)
         quant_method = quantization_config.get("quant_method", "")
@@ -182,6 +185,9 @@ def configure_quantization(
             if model_args.quantization_bit not in [8, 6, 5, 4, 3, 2, 1]:
                 raise ValueError("HQQ only accepts 1/2/3/4/5/6/8-bit quantization.")
 
+            if is_deepspeed_zero3_enabled() or is_fsdp_enabled():
+                raise ValueError("HQQ quantization is incompatible with DeepSpeed ZeRO-3 or FSDP.")
+
             require_version("hqq", "To fix: pip install hqq")
             init_kwargs["quantization_config"] = HqqConfig(
                 nbits=model_args.quantization_bit, quant_zero=False, quant_scale=False, axis=0
@@ -191,6 +197,9 @@ def configure_quantization(
             if model_args.quantization_bit != 8:
                 raise ValueError("EETQ only accepts 8-bit quantization.")
 
+            if is_deepspeed_zero3_enabled() or is_fsdp_enabled():
+                raise ValueError("EETQ quantization is incompatible with DeepSpeed ZeRO-3 or FSDP.")
+
             require_version("eetq", "To fix: pip install eetq")
             init_kwargs["quantization_config"] = EetqConfig()
             logger.info("Quantizing model to {} bit with EETQ.".format(model_args.quantization_bit))
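
Note: the same guard now appears in three branches (PTQ, HQQ, EETQ), since DeepSpeed ZeRO-3 and FSDP shard full-precision parameters across ranks, and quantized weights (packed integer tensors plus scales/zero points) generally cannot be partitioned that way. Below is a minimal standalone sketch of the pattern, assuming the same transformers helpers this module already calls; the helper name assert_no_sharded_runtime is hypothetical, not part of this patch.

# Sketch only: mirrors the guard added in this diff; not LLaMA-Factory API.
from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.modeling_utils import is_fsdp_enabled


def assert_no_sharded_runtime(method: str) -> None:
    # ZeRO-3 and FSDP shard parameters across ranks; quantized weights
    # cannot be partitioned that way, so fail fast before loading.
    if is_deepspeed_zero3_enabled() or is_fsdp_enabled():
        raise ValueError("{} is incompatible with DeepSpeed ZeRO-3 or FSDP.".format(method))


# Usage matching the three branches touched above:
# assert_no_sharded_runtime("PTQ-quantized models")
# assert_no_sharded_runtime("HQQ quantization")
# assert_no_sharded_runtime("EETQ quantization")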