mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-04 12:42:51 +08:00
optimize aqlm training
Former-commit-id: d3d3dac7070eb9055bcdc91eaf53f5b3741c0bda
This commit is contained in:
parent
0f2250b831
commit
c776cdfc3e
@ -1,3 +1,4 @@
|
||||
from contextlib import nullcontext
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple
|
||||
|
||||
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
|
||||
@ -86,6 +87,16 @@ def load_model(
|
||||
logger.warning("Unsloth does not support loading adapters.")
|
||||
|
||||
if model is None:
|
||||
model_init_context = nullcontext()
|
||||
if is_trainable and getattr(config, "quantization_config", None):
|
||||
quantization_config: Dict[str, Any] = getattr(config, "quantization_config", None)
|
||||
if quantization_config.get("quant_method", None) == "aqlm":
|
||||
import aqlm # type: ignore
|
||||
|
||||
model_init_context = aqlm.optimize_for_training()
|
||||
logger.info("Optimize for AQLM training.") # https://github.com/Vahe1994/AQLM/issues/38
|
||||
|
||||
with model_init_context:
|
||||
model = AutoModelForCausalLM.from_pretrained(model_args.model_name_or_path, config=config, **init_kwargs)
|
||||
|
||||
patch_model(model, tokenizer, model_args, is_trainable)
|
||||
|
@ -159,25 +159,25 @@ def _configure_quantization(
|
||||
r"""
|
||||
Priority: PTQ-quantized (training) > AutoGPTQ (export) > Bitsandbytes (training)
|
||||
"""
|
||||
if getattr(config, "quantization_config", None): # gptq
|
||||
if getattr(config, "quantization_config", None): # ptq
|
||||
if is_deepspeed_zero3_enabled():
|
||||
raise ValueError("DeepSpeed ZeRO-3 is incompatible with quantization.")
|
||||
|
||||
init_kwargs["device_map"] = {"": get_current_device()}
|
||||
quantization_config: Dict[str, Any] = getattr(config, "quantization_config", None)
|
||||
if quantization_config.get("quant_method", None) == "gptq" and quantization_config.get("bits", -1) == 4:
|
||||
quant_method = quantization_config.get("quant_method", "")
|
||||
|
||||
if quant_method == "gptq":
|
||||
quantization_config["use_exllama"] = False # disable exllama
|
||||
|
||||
if quantization_config.get("quant_method", None) == "aqlm":
|
||||
if quant_method == "aqlm":
|
||||
require_version(
|
||||
"transformers>=4.39.0.dev0", "To fix: pip install git+https://github.com/huggingface/transformers.git"
|
||||
)
|
||||
quantization_config["bits"] = 2
|
||||
|
||||
logger.info(
|
||||
"Loading {}-bit {}-quantized model.".format(
|
||||
quantization_config.get("bits", "?"), str(quantization_config.get("quant_method", "")).upper()
|
||||
)
|
||||
)
|
||||
quant_bits = quantization_config.get("bits", "?")
|
||||
logger.info("Loading {}-bit {}-quantized model.".format(quant_bits, quant_method.upper()))
|
||||
|
||||
elif model_args.export_quantization_bit is not None: # auto-gptq
|
||||
require_version("optimum>=1.16.0", "To fix: pip install optimum>=1.16.0")
|
||||
@ -213,6 +213,7 @@ def _configure_quantization(
|
||||
bnb_4bit_quant_type=model_args.quantization_type,
|
||||
)
|
||||
|
||||
init_kwargs["device_map"] = {"": get_current_device()}
|
||||
logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit))
|
||||
|
||||
|
||||
@ -285,10 +286,13 @@ def patch_config(
|
||||
init_kwargs["torch_dtype"] = model_args.compute_dtype
|
||||
if not is_deepspeed_zero3_enabled():
|
||||
init_kwargs["low_cpu_mem_usage"] = True
|
||||
if "device_map" not in init_kwargs:
|
||||
if is_trainable:
|
||||
init_kwargs["device_map"] = {"": get_current_device()}
|
||||
elif model_args.export_dir is None:
|
||||
init_kwargs["device_map"] = "auto"
|
||||
else:
|
||||
init_kwargs["device_map"] = {"": "cpu"}
|
||||
|
||||
|
||||
def patch_model(
|
||||
|
Loading…
x
Reference in New Issue
Block a user