diff --git a/src/llmtuner/model/patcher.py b/src/llmtuner/model/patcher.py
index 03e4e5a2..38887b70 100644
--- a/src/llmtuner/model/patcher.py
+++ b/src/llmtuner/model/patcher.py
@@ -50,10 +50,10 @@ def configure_quantization(
     r"""
     Priority: Pre-quantized (training) > AutoGPTQ (export) > Bitsandbytes (training)
     """
-    if is_deepspeed_zero3_enabled():
-        raise ValueError("DeepSpeed ZeRO-3 is incompatible with quantization.")
-
     if getattr(config, "quantization_config", None):  # gptq or awq
+        if is_deepspeed_zero3_enabled():
+            raise ValueError("DeepSpeed ZeRO-3 is incompatible with quantization.")
+
         config_kwargs["device_map"] = {"": get_current_device()}
         quantization_config = getattr(config, "quantization_config", None)
         logger.info("Loading {}-bit pre-quantized model.".format(quantization_config.get("bits", -1)))
@@ -76,6 +76,9 @@ def configure_quantization(
         logger.info("Quantizing model to {} bit.".format(finetuning_args.export_quantization_bit))
 
     elif model_args.quantization_bit is not None:  # bnb
+        if is_deepspeed_zero3_enabled():
+            raise ValueError("DeepSpeed ZeRO-3 is incompatible with quantization.")
+
         if model_args.quantization_bit == 8:
             require_version("bitsandbytes>=0.37.0", "To fix: pip install bitsandbytes>=0.37.0")
             config_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
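
The patch narrows the DeepSpeed ZeRO-3 guard: instead of rejecting every call to configure_quantization under ZeRO-3, the ValueError is now raised only on the two training paths that actually quantize weights at load time (pre-quantized gptq/awq checkpoints and bitsandbytes), leaving the AutoGPTQ export branch and non-quantized setups untouched. Below is a minimal sketch of the resulting control flow, not the real function: the flat keyword arguments, the zero3_enabled flag, and the device index 0 are hypothetical stand-ins for the config/model_args/finetuning_args objects, is_deepspeed_zero3_enabled(), and get_current_device() used by the actual patcher.

    # Sketch only: argument names and stand-in values are assumptions,
    # reconstructed from the branch structure visible in the diff above.
    from typing import Any, Dict, Optional


    def configure_quantization_sketch(
        pretrained_quant_config: Optional[Dict[str, Any]] = None,  # gptq or awq
        export_quantization_bit: Optional[int] = None,  # AutoGPTQ export
        quantization_bit: Optional[int] = None,  # bitsandbytes
        zero3_enabled: bool = False,  # stand-in for is_deepspeed_zero3_enabled()
    ) -> Dict[str, Any]:
        config_kwargs: Dict[str, Any] = {}
        if pretrained_quant_config is not None:  # highest priority: pre-quantized
            if zero3_enabled:  # check moved inside this branch by the patch
                raise ValueError("DeepSpeed ZeRO-3 is incompatible with quantization.")
            config_kwargs["device_map"] = {"": 0}  # stand-in for get_current_device()
        elif export_quantization_bit is not None:  # export path: no ZeRO-3 check
            config_kwargs["bits"] = export_quantization_bit
        elif quantization_bit is not None:  # bnb path: check added by the patch
            if zero3_enabled:
                raise ValueError("DeepSpeed ZeRO-3 is incompatible with quantization.")
            config_kwargs["load_in_8bit"] = quantization_bit == 8
        return config_kwargs

With no quantization requested, configure_quantization_sketch(zero3_enabled=True) returns an empty dict instead of raising, which is the behavioral change the diff introduces: ZeRO-3 training without quantization is no longer blocked.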