Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-08-04 04:32:50 +08:00)
optimize aqlm training

commit c776cdfc3e (parent 0f2250b831)
Former-commit-id: d3d3dac7070eb9055bcdc91eaf53f5b3741c0bda
@@ -1,3 +1,4 @@
+from contextlib import nullcontext
 from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple

 from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
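The only new import is contextlib.nullcontext, a no-op context manager that lets the non-AQLM path share the same `with` block as the optimized path. A tiny self-contained sketch of that pattern (the announce() helper is purely illustrative):

from contextlib import contextmanager, nullcontext


@contextmanager
def announce(name):
    # Illustrative context manager standing in for aqlm.optimize_for_training().
    print("enter", name)
    yield
    print("exit", name)


optimize = False
ctx = announce("aqlm") if optimize else nullcontext()

with ctx:
    # This body runs either way; enter/exit messages appear only when optimize is True.
    print("loading model...")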
@@ -86,7 +87,17 @@ def load_model(
             logger.warning("Unsloth does not support loading adapters.")

     if model is None:
-        model = AutoModelForCausalLM.from_pretrained(model_args.model_name_or_path, config=config, **init_kwargs)
+        model_init_context = nullcontext()
+        if is_trainable and getattr(config, "quantization_config", None):
+            quantization_config: Dict[str, Any] = getattr(config, "quantization_config", None)
+            if quantization_config.get("quant_method", None) == "aqlm":
+                import aqlm  # type: ignore
+
+                model_init_context = aqlm.optimize_for_training()
+                logger.info("Optimize for AQLM training.")  # https://github.com/Vahe1994/AQLM/issues/38
+
+        with model_init_context:
+            model = AutoModelForCausalLM.from_pretrained(model_args.model_name_or_path, config=config, **init_kwargs)

     patch_model(model, tokenizer, model_args, is_trainable)
     register_autoclass(config, model, tokenizer)
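The hunk above wraps model instantiation in aqlm.optimize_for_training() when fine-tuning an AQLM checkpoint (see the linked issue). A minimal sketch of the same pattern outside LLaMA-Factory, assuming the aqlm package is installed; the checkpoint name in the usage comment is only an example:

from contextlib import nullcontext

from transformers import AutoModelForCausalLM


def load_aqlm_model(model_name_or_path: str, is_trainable: bool) -> "PreTrainedModel":
    # Default to a no-op context so the non-AQLM / inference path is unchanged.
    init_context = nullcontext()
    if is_trainable:
        import aqlm  # type: ignore  # requires the `aqlm` package

        # See https://github.com/Vahe1994/AQLM/issues/38 for why AQLM models
        # should be instantiated inside this context before training.
        init_context = aqlm.optimize_for_training()

    with init_context:
        return AutoModelForCausalLM.from_pretrained(model_name_or_path)


# Hypothetical usage with an example 2-bit AQLM checkpoint:
# model = load_aqlm_model("ISTA-DASLab/Llama-2-7b-AQLM-2Bit-1x16-hf", is_trainable=True)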
@@ -159,25 +159,25 @@ def _configure_quantization(
     r"""
     Priority: PTQ-quantized (training) > AutoGPTQ (export) > Bitsandbytes (training)
     """
-    if getattr(config, "quantization_config", None):  # gptq
+    if getattr(config, "quantization_config", None):  # ptq
         if is_deepspeed_zero3_enabled():
            raise ValueError("DeepSpeed ZeRO-3 is incompatible with quantization.")

+        init_kwargs["device_map"] = {"": get_current_device()}
         quantization_config: Dict[str, Any] = getattr(config, "quantization_config", None)
-        if quantization_config.get("quant_method", None) == "gptq" and quantization_config.get("bits", -1) == 4:
+        quant_method = quantization_config.get("quant_method", "")
+
+        if quant_method == "gptq":
             quantization_config["use_exllama"] = False  # disable exllama

-        if quantization_config.get("quant_method", None) == "aqlm":
+        if quant_method == "aqlm":
             require_version(
                 "transformers>=4.39.0.dev0", "To fix: pip install git+https://github.com/huggingface/transformers.git"
             )
             quantization_config["bits"] = 2

-        logger.info(
-            "Loading {}-bit {}-quantized model.".format(
-                quantization_config.get("bits", "?"), str(quantization_config.get("quant_method", "")).upper()
-            )
-        )
+        quant_bits = quantization_config.get("bits", "?")
+        logger.info("Loading {}-bit {}-quantized model.".format(quant_bits, quant_method.upper()))

     elif model_args.export_quantization_bit is not None:  # auto-gptq
         require_version("optimum>=1.16.0", "To fix: pip install optimum>=1.16.0")
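The refactor above reads quant_method once from the checkpoint's quantization_config and branches on it instead of re-querying the dict. A minimal standalone sketch of that dispatch; describe_ptq_checkpoint() and the literal config dict are illustrative, not part of the codebase:

from typing import Any, Dict


def describe_ptq_checkpoint(quantization_config: Dict[str, Any]) -> str:
    """Mirror the branching above: read quant_method once, then special-case it."""
    quant_method = quantization_config.get("quant_method", "")

    if quant_method == "gptq":
        # GPTQ checkpoints are loaded with exllama kernels disabled.
        quantization_config["use_exllama"] = False

    if quant_method == "aqlm":
        # The patch above pins AQLM checkpoints to 2 bits for reporting.
        quantization_config["bits"] = 2

    quant_bits = quantization_config.get("bits", "?")
    return "Loading {}-bit {}-quantized model.".format(quant_bits, quant_method.upper())


# Example: a GPTQ-style quantization_config as it would appear in a checkpoint's config.json.
print(describe_ptq_checkpoint({"quant_method": "gptq", "bits": 4}))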
@@ -213,6 +213,7 @@ def _configure_quantization(
             bnb_4bit_quant_type=model_args.quantization_type,
         )

+        init_kwargs["device_map"] = {"": get_current_device()}
         logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit))

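For the bitsandbytes path, the added line pins the quantized model to the current device rather than leaving device_map to be filled in later by patch_config. A rough sketch of the resulting from_pretrained call, assuming a single CUDA device and the bitsandbytes/accelerate packages; the checkpoint name is a placeholder:

import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# 4-bit NF4 quantization, matching the options configured above.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
)

# Pin the whole model to one device; {"": 0} maps every module to cuda:0.
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",  # placeholder checkpoint
    quantization_config=bnb_config,
    device_map={"": 0},
)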
@@ -285,10 +286,13 @@ def patch_config(
     init_kwargs["torch_dtype"] = model_args.compute_dtype
     if not is_deepspeed_zero3_enabled():
         init_kwargs["low_cpu_mem_usage"] = True
-        if is_trainable:
-            init_kwargs["device_map"] = {"": get_current_device()}
-        elif model_args.export_dir is None:
-            init_kwargs["device_map"] = "auto"
+        if "device_map" not in init_kwargs:
+            if is_trainable:
+                init_kwargs["device_map"] = {"": get_current_device()}
+            elif model_args.export_dir is None:
+                init_kwargs["device_map"] = "auto"
+            else:
+                init_kwargs["device_map"] = {"": "cpu"}


 def patch_model(
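With this change, patch_config only fills in device_map when an earlier step such as _configure_quantization has not already set it. A minimal sketch of that precedence; resolve_device_map() and current_device() are hypothetical helpers standing in for the real code and get_current_device():

from typing import Any, Dict, Optional, Union

import torch


def current_device() -> Union[int, str]:
    # Hypothetical stand-in for LLaMA-Factory's get_current_device().
    return torch.cuda.current_device() if torch.cuda.is_available() else "cpu"


def resolve_device_map(init_kwargs: Dict[str, Any], is_trainable: bool, export_dir: Optional[str]) -> None:
    # Respect a device_map already chosen by the quantization step.
    if "device_map" not in init_kwargs:
        if is_trainable:
            init_kwargs["device_map"] = {"": current_device()}  # train on a single device
        elif export_dir is None:
            init_kwargs["device_map"] = "auto"  # let accelerate shard for inference
        else:
            init_kwargs["device_map"] = {"": "cpu"}  # assemble exported models on CPU


kwargs: Dict[str, Any] = {}
resolve_device_map(kwargs, is_trainable=True, export_dir=None)
print(kwargs)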