mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-12-19 13:20:36 +08:00)
support rank0 logger
@@ -28,8 +28,8 @@ from transformers.integrations import is_deepspeed_zero3_enabled
 from transformers.modeling_utils import is_fsdp_enabled
 from transformers.utils.versions import require_version

+from ...extras import logging
 from ...extras.constants import FILEEXT2TYPE
-from ...extras.logging import get_logger
 from ...extras.misc import get_current_device


@@ -39,7 +39,7 @@ if TYPE_CHECKING:
     from ...hparams import ModelArguments


-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)


 @unique
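Note on the change above: the new import `from ...extras import logging` indicates that the project's own logging wrapper now provides the `*_rank0` methods used throughout this diff. That module is not part of this commit, so the following is only a minimal sketch of how such helpers could work, assuming a torchrun-style launcher that sets `LOCAL_RANK` on each worker; the environment-variable gate and the monkey-patching approach are assumptions, not the repository's actual code.

import logging
import os


def _is_rank0() -> bool:
    # torchrun/accelerate set LOCAL_RANK per worker; single-process runs default to "0"
    return int(os.getenv("LOCAL_RANK", "0")) == 0


def info_rank0(self: logging.Logger, *args, **kwargs) -> None:
    # emit only on the main process so multi-GPU runs log each message once
    if _is_rank0():
        self.info(*args, **kwargs)


def warning_rank0(self: logging.Logger, *args, **kwargs) -> None:
    if _is_rank0():
        self.warning(*args, **kwargs)


# attach the helpers to every Logger instance (hypothetical approach)
logging.Logger.info_rank0 = info_rank0
logging.Logger.warning_rank0 = warning_rank0


def get_logger(name: str) -> logging.Logger:
    return logging.getLogger(name)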
@@ -109,7 +109,7 @@ def configure_quantization(
     """
     if getattr(config, "quantization_config", None):  # ptq
         if model_args.quantization_bit is not None:
-            logger.warning("`quantization_bit` will not affect on the PTQ-quantized models.")
+            logger.warning_rank0("`quantization_bit` will not affect on the PTQ-quantized models.")

         if is_deepspeed_zero3_enabled() or is_fsdp_enabled():
             raise ValueError("DeepSpeed ZeRO-3 or FSDP is incompatible with PTQ-quantized models.")
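For readers outside the codebase: `config.quantization_config` is set when the checkpoint was already quantized ahead of time (PTQ), which is why an on-the-fly `quantization_bit` request is only warned about rather than applied. A hypothetical example of such metadata in a checkpoint's config.json (field names follow transformers' GPTQ integration; values are illustrative):

"quantization_config": {
    "quant_method": "gptq",
    "bits": 4,
    "group_size": 128
}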
@@ -130,7 +130,7 @@ def configure_quantization(
             quantization_config["bits"] = 2

         quant_bits = quantization_config.get("bits", "?")
-        logger.info(f"Loading {quant_bits}-bit {quant_method.upper()}-quantized model.")
+        logger.info_rank0(f"Loading {quant_bits}-bit {quant_method.upper()}-quantized model.")

     elif model_args.export_quantization_bit is not None:  # auto-gptq
         if model_args.export_quantization_bit not in [8, 4, 3, 2]:
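This hunk's log line reads the bit width from the checkpoint's quantization metadata via `quantization_config.get("bits", "?")`, falling back to a literal "?" when the method does not record a `bits` field, so the rank-0 message (e.g. "Loading 4-bit GPTQ-quantized model.") stays accurate for any PTQ method.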
@@ -149,7 +149,7 @@ def configure_quantization(
         )
         init_kwargs["device_map"] = "auto"
         init_kwargs["max_memory"] = get_max_memory()
-        logger.info(f"Quantizing model to {model_args.export_quantization_bit} bit with AutoGPTQ.")
+        logger.info_rank0(f"Quantizing model to {model_args.export_quantization_bit} bit with AutoGPTQ.")

     elif model_args.quantization_bit is not None:  # on-the-fly
         if model_args.quantization_method == QuantizationMethod.BITS_AND_BYTES.value:
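The closing parenthesis at the top of this hunk belongs to the elided construction of the GPTQ export config. As a hedged sketch only, using transformers' public `GPTQConfig` API with placeholder arguments rather than the repository's exact ones:

from transformers import GPTQConfig

init_kwargs = {}
export_quantization_bit = 4  # stands in for model_args.export_quantization_bit (8/4/3/2)
init_kwargs["quantization_config"] = GPTQConfig(
    bits=export_quantization_bit,
    dataset="c4",  # placeholder calibration dataset
)
init_kwargs["device_map"] = "auto"  # as in the hunk above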
@@ -179,7 +179,7 @@ def configure_quantization(
             else:
                 init_kwargs["device_map"] = {"": get_current_device()}  # change auto device map for inference

-            logger.info(f"Quantizing model to {model_args.quantization_bit} bit with bitsandbytes.")
+            logger.info_rank0(f"Quantizing model to {model_args.quantization_bit} bit with bitsandbytes.")
         elif model_args.quantization_method == QuantizationMethod.HQQ.value:
             if model_args.quantization_bit not in [8, 6, 5, 4, 3, 2, 1]:
                 raise ValueError("HQQ only accepts 1/2/3/4/5/6/8-bit quantization.")
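The bitsandbytes branch mostly lives outside this hunk; for orientation, a representative on-the-fly 4-bit setup using transformers' documented `BitsAndBytesConfig` fields (values illustrative, not the repository's exact flags):

import torch
from transformers import BitsAndBytesConfig

init_kwargs = {}
init_kwargs["quantization_config"] = BitsAndBytesConfig(
    load_in_4bit=True,                      # quantize weights to 4 bits at load time
    bnb_4bit_compute_dtype=torch.bfloat16,  # run matmuls in bf16
    bnb_4bit_quant_type="nf4",              # NormalFloat4 quantization
)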
@@ -191,7 +191,7 @@ def configure_quantization(
             init_kwargs["quantization_config"] = HqqConfig(
                 nbits=model_args.quantization_bit, quant_zero=False, quant_scale=False, axis=0
             )  # use ATEN kernel (axis=0) for performance
-            logger.info(f"Quantizing model to {model_args.quantization_bit} bit with HQQ.")
+            logger.info_rank0(f"Quantizing model to {model_args.quantization_bit} bit with HQQ.")
         elif model_args.quantization_method == QuantizationMethod.EETQ.value:
             if model_args.quantization_bit != 8:
                 raise ValueError("EETQ only accepts 8-bit quantization.")
@@ -201,4 +201,4 @@ def configure_quantization(

             require_version("eetq", "To fix: pip install eetq")
             init_kwargs["quantization_config"] = EetqConfig()
-            logger.info(f"Quantizing model to {model_args.quantization_bit} bit with EETQ.")
+            logger.info_rank0(f"Quantizing model to {model_args.quantization_bit} bit with EETQ.")
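Net effect of the commit: every `logger.info`/`logger.warning` in this file becomes a rank-0 variant, so a distributed launch such as `torchrun --nproc_per_node=8 ...` should print each quantization message once from the main process rather than once per worker, assuming the helpers gate on rank as sketched earlier.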