From 79f301a2c673586c230f4ea843ed6e314220bba9 Mon Sep 17 00:00:00 2001 From: hiyouga Date: Thu, 21 Dec 2023 01:19:22 +0800 Subject: [PATCH] fix ds zero3 check Former-commit-id: 7f50705b1d821d287bd854211319f697f57b25db --- src/llmtuner/model/patcher.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/llmtuner/model/patcher.py b/src/llmtuner/model/patcher.py index 03e4e5a2..38887b70 100644 --- a/src/llmtuner/model/patcher.py +++ b/src/llmtuner/model/patcher.py @@ -50,10 +50,10 @@ def configure_quantization( r""" Priority: Pre-quantized (training) > AutoGPTQ (export) > Bitsandbytes (training) """ - if is_deepspeed_zero3_enabled(): - raise ValueError("DeepSpeed ZeRO-3 is incompatible with quantization.") - if getattr(config, "quantization_config", None): # gptq or awq + if is_deepspeed_zero3_enabled(): + raise ValueError("DeepSpeed ZeRO-3 is incompatible with quantization.") + config_kwargs["device_map"] = {"": get_current_device()} quantization_config = getattr(config, "quantization_config", None) logger.info("Loading {}-bit pre-quantized model.".format(quantization_config.get("bits", -1))) @@ -76,6 +76,9 @@ def configure_quantization( logger.info("Quantizing model to {} bit.".format(finetuning_args.export_quantization_bit)) elif model_args.quantization_bit is not None: # bnb + if is_deepspeed_zero3_enabled(): + raise ValueError("DeepSpeed ZeRO-3 is incompatible with quantization.") + if model_args.quantization_bit == 8: require_version("bitsandbytes>=0.37.0", "To fix: pip install bitsandbytes>=0.37.0") config_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)