mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-10-14 23:58:11 +08:00
fix ds zero3 check
Former-commit-id: 7f50705b1d821d287bd854211319f697f57b25db
This commit is contained in:
parent
31cbc67986
commit
79f301a2c6
@ -50,10 +50,10 @@ def configure_quantization(
|
||||
r"""
|
||||
Priority: Pre-quantized (training) > AutoGPTQ (export) > Bitsandbytes (training)
|
||||
"""
|
||||
if is_deepspeed_zero3_enabled():
|
||||
raise ValueError("DeepSpeed ZeRO-3 is incompatible with quantization.")
|
||||
|
||||
if getattr(config, "quantization_config", None): # gptq or awq
|
||||
if is_deepspeed_zero3_enabled():
|
||||
raise ValueError("DeepSpeed ZeRO-3 is incompatible with quantization.")
|
||||
|
||||
config_kwargs["device_map"] = {"": get_current_device()}
|
||||
quantization_config = getattr(config, "quantization_config", None)
|
||||
logger.info("Loading {}-bit pre-quantized model.".format(quantization_config.get("bits", -1)))
|
||||
@ -76,6 +76,9 @@ def configure_quantization(
|
||||
logger.info("Quantizing model to {} bit.".format(finetuning_args.export_quantization_bit))
|
||||
|
||||
elif model_args.quantization_bit is not None: # bnb
|
||||
if is_deepspeed_zero3_enabled():
|
||||
raise ValueError("DeepSpeed ZeRO-3 is incompatible with quantization.")
|
||||
|
||||
if model_args.quantization_bit == 8:
|
||||
require_version("bitsandbytes>=0.37.0", "To fix: pip install bitsandbytes>=0.37.0")
|
||||
config_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
|
||||
|
Loading…
x
Reference in New Issue
Block a user