optimize aqlm training

Former-commit-id: d3d3dac7070eb9055bcdc91eaf53f5b3741c0bda
This commit is contained in:
hiyouga 2024-03-05 18:35:41 +08:00
parent 0f2250b831
commit c776cdfc3e
2 changed files with 28 additions and 13 deletions

View File

@ -1,3 +1,4 @@
from contextlib import nullcontext
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
@ -86,6 +87,16 @@ def load_model(
logger.warning("Unsloth does not support loading adapters.") logger.warning("Unsloth does not support loading adapters.")
if model is None: if model is None:
model_init_context = nullcontext()
if is_trainable and getattr(config, "quantization_config", None):
quantization_config: Dict[str, Any] = getattr(config, "quantization_config", None)
if quantization_config.get("quant_method", None) == "aqlm":
import aqlm # type: ignore
model_init_context = aqlm.optimize_for_training()
logger.info("Optimize for AQLM training.") # https://github.com/Vahe1994/AQLM/issues/38
with model_init_context:
model = AutoModelForCausalLM.from_pretrained(model_args.model_name_or_path, config=config, **init_kwargs) model = AutoModelForCausalLM.from_pretrained(model_args.model_name_or_path, config=config, **init_kwargs)
patch_model(model, tokenizer, model_args, is_trainable) patch_model(model, tokenizer, model_args, is_trainable)

View File

@ -159,25 +159,25 @@ def _configure_quantization(
r""" r"""
Priority: PTQ-quantized (training) > AutoGPTQ (export) > Bitsandbytes (training) Priority: PTQ-quantized (training) > AutoGPTQ (export) > Bitsandbytes (training)
""" """
if getattr(config, "quantization_config", None): # gptq if getattr(config, "quantization_config", None): # ptq
if is_deepspeed_zero3_enabled(): if is_deepspeed_zero3_enabled():
raise ValueError("DeepSpeed ZeRO-3 is incompatible with quantization.") raise ValueError("DeepSpeed ZeRO-3 is incompatible with quantization.")
init_kwargs["device_map"] = {"": get_current_device()}
quantization_config: Dict[str, Any] = getattr(config, "quantization_config", None) quantization_config: Dict[str, Any] = getattr(config, "quantization_config", None)
if quantization_config.get("quant_method", None) == "gptq" and quantization_config.get("bits", -1) == 4: quant_method = quantization_config.get("quant_method", "")
if quant_method == "gptq":
quantization_config["use_exllama"] = False # disable exllama quantization_config["use_exllama"] = False # disable exllama
if quantization_config.get("quant_method", None) == "aqlm": if quant_method == "aqlm":
require_version( require_version(
"transformers>=4.39.0.dev0", "To fix: pip install git+https://github.com/huggingface/transformers.git" "transformers>=4.39.0.dev0", "To fix: pip install git+https://github.com/huggingface/transformers.git"
) )
quantization_config["bits"] = 2 quantization_config["bits"] = 2
logger.info( quant_bits = quantization_config.get("bits", "?")
"Loading {}-bit {}-quantized model.".format( logger.info("Loading {}-bit {}-quantized model.".format(quant_bits, quant_method.upper()))
quantization_config.get("bits", "?"), str(quantization_config.get("quant_method", "")).upper()
)
)
elif model_args.export_quantization_bit is not None: # auto-gptq elif model_args.export_quantization_bit is not None: # auto-gptq
require_version("optimum>=1.16.0", "To fix: pip install optimum>=1.16.0") require_version("optimum>=1.16.0", "To fix: pip install optimum>=1.16.0")
@ -213,6 +213,7 @@ def _configure_quantization(
bnb_4bit_quant_type=model_args.quantization_type, bnb_4bit_quant_type=model_args.quantization_type,
) )
init_kwargs["device_map"] = {"": get_current_device()}
logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit)) logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit))
@ -285,10 +286,13 @@ def patch_config(
init_kwargs["torch_dtype"] = model_args.compute_dtype init_kwargs["torch_dtype"] = model_args.compute_dtype
if not is_deepspeed_zero3_enabled(): if not is_deepspeed_zero3_enabled():
init_kwargs["low_cpu_mem_usage"] = True init_kwargs["low_cpu_mem_usage"] = True
if "device_map" not in init_kwargs:
if is_trainable: if is_trainable:
init_kwargs["device_map"] = {"": get_current_device()} init_kwargs["device_map"] = {"": get_current_device()}
elif model_args.export_dir is None: elif model_args.export_dir is None:
init_kwargs["device_map"] = "auto" init_kwargs["device_map"] = "auto"
else:
init_kwargs["device_map"] = {"": "cpu"}
def patch_model( def patch_model(