mirror of https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-12-19 13:20:36 +08:00
improve log
@@ -26,11 +26,10 @@ from datasets import load_dataset
 from transformers import BitsAndBytesConfig, EetqConfig, GPTQConfig, HqqConfig
 from transformers.integrations import is_deepspeed_zero3_enabled
 from transformers.modeling_utils import is_fsdp_enabled
-from transformers.utils.versions import require_version

 from ...extras import logging
 from ...extras.constants import FILEEXT2TYPE
-from ...extras.misc import get_current_device
+from ...extras.misc import check_version, get_current_device


 if TYPE_CHECKING:
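The newly imported check_version helper from ...extras.misc centralizes what require_version did with a hand-written hint string at every call site. Its definition is not part of this diff; a minimal sketch, assuming it simply derives the pip hint from the requirement string and delegates to transformers:

# Hypothetical sketch only -- the real helper lives in the project's extras.misc module
# and may add more behavior (e.g. an opt-out for non-mandatory checks).
from transformers.utils.versions import require_version


def check_version(requirement: str, mandatory: bool = False) -> None:
    """Check an optional or mandatory dependency against a pip-style requirement string."""
    hint = f"To fix: run `pip install {requirement}`."
    if not mandatory:
        hint += " Install it only if you need this feature."  # assumed wording

    require_version(requirement, hint)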
@@ -118,15 +117,15 @@ def configure_quantization(
         quant_method = quantization_config.get("quant_method", "")

         if quant_method == QuantizationMethod.GPTQ:
-            require_version("auto_gptq>=0.5.0", "To fix: pip install auto_gptq>=0.5.0")
+            check_version("auto_gptq>=0.5.0", mandatory=True)
             quantization_config.pop("disable_exllama", None)  # remove deprecated args
             quantization_config["use_exllama"] = False  # disable exllama

         if quant_method == QuantizationMethod.AWQ:
-            require_version("autoawq", "To fix: pip install autoawq")
+            check_version("autoawq", mandatory=True)

         if quant_method == QuantizationMethod.AQLM:
-            require_version("aqlm>=1.1.0", "To fix: pip install aqlm[gpu]>=1.1.0")
+            check_version("aqlm>=1.1.0", mandatory=True)
             quantization_config["bits"] = 2

         quant_bits = quantization_config.get("bits", "?")
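For reference, the quantization_config inspected in this hunk comes from a pre-quantized checkpoint's config.json. An illustrative (made-up) example of such an entry for a GPTQ model, showing where quant_method and bits are read from:

# Illustrative values only -- exact keys depend on the quantizer and exporter version.
quantization_config = {
    "quant_method": "gptq",  # compared against QuantizationMethod.GPTQ above
    "bits": 4,               # surfaces later as quant_bits in the log message
    "group_size": 128,
    "desc_act": False,
}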
@@ -136,8 +135,8 @@ def configure_quantization(
         if model_args.export_quantization_bit not in [8, 4, 3, 2]:
             raise ValueError("AutoGPTQ only accepts 2/3/4/8-bit quantization.")

-        require_version("optimum>=1.17.0", "To fix: pip install optimum>=1.17.0")
-        require_version("auto_gptq>=0.5.0", "To fix: pip install auto_gptq>=0.5.0")
+        check_version("optimum>=1.17.0", mandatory=True)
+        check_version("auto_gptq>=0.5.0", mandatory=True)
         from accelerate.utils import get_max_memory

         if getattr(config, "model_type", None) == "chatglm":
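This hunk covers export-time quantization, which depends on optimum and auto_gptq. As a standalone illustration (not the project's exact export path; the model id and calibration dataset are placeholders), the same post-training GPTQ flow can be expressed with transformers directly:

from transformers import AutoModelForCausalLM, AutoTokenizer, GPTQConfig

model_id = "facebook/opt-125m"  # placeholder model
tokenizer = AutoTokenizer.from_pretrained(model_id)

# 4-bit GPTQ quantization; optimum and auto_gptq do the actual work under the hood.
gptq_config = GPTQConfig(bits=4, dataset="c4", tokenizer=tokenizer)
model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=gptq_config, device_map="auto")
model.save_pretrained("opt-125m-gptq-4bit")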
@@ -154,10 +153,10 @@ def configure_quantization(
     elif model_args.quantization_bit is not None:  # on-the-fly
         if model_args.quantization_method == QuantizationMethod.BITS_AND_BYTES.value:
             if model_args.quantization_bit == 8:
-                require_version("bitsandbytes>=0.37.0", "To fix: pip install bitsandbytes>=0.37.0")
+                check_version("bitsandbytes>=0.37.0", mandatory=True)
                 init_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
             elif model_args.quantization_bit == 4:
-                require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0")
+                check_version("bitsandbytes>=0.39.0", mandatory=True)
                 init_kwargs["quantization_config"] = BitsAndBytesConfig(
                     load_in_4bit=True,
                     bnb_4bit_compute_dtype=model_args.compute_dtype,
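The on-the-fly branch above builds a BitsAndBytesConfig from the parsed model arguments; the hunk cuts off before the remaining keyword arguments. A self-contained sketch of the equivalent 4-bit load, with placeholder values standing in for model_args:

import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,  # stands in for model_args.compute_dtype
    bnb_4bit_quant_type="nf4",              # assumed; in the project this comes from model_args
    bnb_4bit_use_double_quant=True,         # assumed; likewise configurable
)
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",  # placeholder model id
    quantization_config=bnb_config,
)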
@@ -175,7 +174,7 @@ def configure_quantization(
                 if model_args.quantization_bit != 4:
                     raise ValueError("Only 4-bit quantized model can use fsdp+qlora or auto device map.")

-                require_version("bitsandbytes>=0.43.0", "To fix: pip install bitsandbytes>=0.43.0")
+                check_version("bitsandbytes>=0.43.0", mandatory=True)
             else:
                 init_kwargs["device_map"] = {"": get_current_device()}  # change auto device map for inference

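The bitsandbytes>=0.43.0 floor in this branch lines up with the FSDP+QLoRA support added around that release, where the 4-bit weights are stored in the compute dtype so FSDP can shard them. A hedged fragment of the extra BitsAndBytesConfig field this typically involves (the parameter exists in recent transformers; whether and how this file sets it lies outside the hunks shown):

import torch
from transformers import BitsAndBytesConfig

BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_quant_storage=torch.bfloat16,  # storage dtype should match compute dtype for FSDP sharding
)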
@@ -187,7 +186,7 @@ def configure_quantization(
             if is_deepspeed_zero3_enabled() or is_fsdp_enabled():
                 raise ValueError("HQQ quantization is incompatible with DeepSpeed ZeRO-3 or FSDP.")

-            require_version("hqq", "To fix: pip install hqq")
+            check_version("hqq", mandatory=True)
             init_kwargs["quantization_config"] = HqqConfig(
                 nbits=model_args.quantization_bit, quant_zero=False, quant_scale=False, axis=0
             )  # use ATEN kernel (axis=0) for performance
@@ -199,6 +198,6 @@ def configure_quantization(
             if is_deepspeed_zero3_enabled() or is_fsdp_enabled():
                 raise ValueError("EETQ quantization is incompatible with DeepSpeed ZeRO-3 or FSDP.")

-            require_version("eetq", "To fix: pip install eetq")
+            check_version("eetq", mandatory=True)
             init_kwargs["quantization_config"] = EetqConfig()
             logger.info_rank0(f"Quantizing model to {model_args.quantization_bit} bit with EETQ.")