From 23a875a8b1e5118466c3e75d27124de3186c191e Mon Sep 17 00:00:00 2001
From: hiyouga
Date: Wed, 20 Dec 2023 18:27:16 +0800
Subject: [PATCH] improve quantization

Former-commit-id: 624cc212819b7cd16295c72084cd454b67cf89a6
---
 src/llmtuner/model/patcher.py | 46 ++++++++++++++++++-----------------
 1 file changed, 24 insertions(+), 22 deletions(-)

diff --git a/src/llmtuner/model/patcher.py b/src/llmtuner/model/patcher.py
index ec5f0ddd..03e4e5a2 100644
--- a/src/llmtuner/model/patcher.py
+++ b/src/llmtuner/model/patcher.py
@@ -47,33 +47,18 @@ def configure_quantization(
     model_args: "ModelArguments",
     finetuning_args: "FinetuningArguments"
 ):
+    r"""
+    Priority: Pre-quantized (training) > AutoGPTQ (export) > Bitsandbytes (training)
+    """
+    if is_deepspeed_zero3_enabled():
+        raise ValueError("DeepSpeed ZeRO-3 is incompatible with quantization.")
+
     if getattr(config, "quantization_config", None): # gptq or awq
-        model_args.quantization_bit = None # remove bnb quantization
         config_kwargs["device_map"] = {"": get_current_device()}
         quantization_config = getattr(config, "quantization_config", None)
         logger.info("Loading {}-bit pre-quantized model.".format(quantization_config.get("bits", -1)))
 
-    if model_args.quantization_bit is not None: # bnb
-        if is_deepspeed_zero3_enabled():
-            raise ValueError("DeepSpeed ZeRO-3 is incompatible with quantization.")
-
-        if model_args.quantization_bit == 8:
-            require_version("bitsandbytes>=0.37.0", "To fix: pip install bitsandbytes>=0.37.0")
-            config_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
-
-        if model_args.quantization_bit == 4:
-            require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0")
-            config_kwargs["quantization_config"] = BitsAndBytesConfig(
-                load_in_4bit=True,
-                bnb_4bit_compute_dtype=model_args.compute_dtype,
-                bnb_4bit_use_double_quant=model_args.double_quantization,
-                bnb_4bit_quant_type=model_args.quantization_type
-            )
-
-        config_kwargs["device_map"] = {"": get_current_device()}
-        logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit))
-
-    if finetuning_args.export_quantization_bit is not None: # gptq
+    elif finetuning_args.export_quantization_bit is not None: # gptq
         require_version("optimum>=1.16.0", "To fix: pip install optimum>=1.16.0")
         require_version("auto_gptq>=0.5.0", "To fix: pip install auto_gptq>=0.5.0")
         from accelerate.utils import get_max_memory
@@ -90,6 +75,23 @@ def configure_quantization(
         config_kwargs["max_memory"] = get_max_memory()
         logger.info("Quantizing model to {} bit.".format(finetuning_args.export_quantization_bit))
 
+    elif model_args.quantization_bit is not None: # bnb
+        if model_args.quantization_bit == 8:
+            require_version("bitsandbytes>=0.37.0", "To fix: pip install bitsandbytes>=0.37.0")
+            config_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
+
+        elif model_args.quantization_bit == 4:
+            require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0")
+            config_kwargs["quantization_config"] = BitsAndBytesConfig(
+                load_in_4bit=True,
+                bnb_4bit_compute_dtype=model_args.compute_dtype,
+                bnb_4bit_use_double_quant=model_args.double_quantization,
+                bnb_4bit_quant_type=model_args.quantization_type
+            )
+
+        config_kwargs["device_map"] = {"": get_current_device()}
+        logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit))
+
 
 def configure_rope(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool):
     if model_args.rope_scaling is not None:
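
Review note (not part of the patch): the rewrite turns three independent if blocks into a single
if/elif chain, so the backend priority becomes pre-quantized checkpoint (GPTQ/AWQ) > AutoGPTQ
export > bitsandbytes, and the DeepSpeed ZeRO-3 guard now sits ahead of every branch rather than
only the bitsandbytes one. Below is a minimal, self-contained sketch of that selection order; the
helper select_quantization_backend and the small dataclasses are hypothetical, for illustration
only, and the real configure_quantization mutates config_kwargs instead of returning a label.

    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class ModelArgs:
        quantization_bit: Optional[int] = None         # bitsandbytes bits (4 or 8)

    @dataclass
    class FinetuningArgs:
        export_quantization_bit: Optional[int] = None  # AutoGPTQ bits used at export time

    def select_quantization_backend(
        has_quantization_config: bool,  # checkpoint already carries a quantization_config (GPTQ/AWQ)
        model_args: ModelArgs,
        finetuning_args: FinetuningArgs,
    ) -> str:
        """Mirror the if/elif priority of the patched configure_quantization()."""
        if has_quantization_config:
            return "pre-quantized (gptq/awq)"  # 1. pre-quantized checkpoint wins
        elif finetuning_args.export_quantization_bit is not None:
            return "auto-gptq export"          # 2. quantized export via AutoGPTQ
        elif model_args.quantization_bit is not None:
            return "bitsandbytes"              # 3. bnb quantization for training
        return "none"

    # A pre-quantized checkpoint takes priority even if quantization_bit is also set:
    print(select_quantization_backend(True, ModelArgs(quantization_bit=4), FinetuningArgs()))
    # -> pre-quantized (gptq/awq)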