diff --git a/src/llamafactory/model/model_utils/quantization.py b/src/llamafactory/model/model_utils/quantization.py index ebffbbc7f..59dfc502a 100644 --- a/src/llamafactory/model/model_utils/quantization.py +++ b/src/llamafactory/model/model_utils/quantization.py @@ -110,7 +110,7 @@ def configure_quantization( check_version("aqlm>=1.1.0", mandatory=True) quantization_config["bits"] = 2 - if quant_method == QuantizationMethod.FP8 and is_trainable: + if quant_method == QuantizationMethod.FP8: quant_config = FineGrainedFP8Config(dequantize=True) init_kwargs["quantization_config"] = quant_config