From c776cdfc3e392aaacaba2e6a778dc5b4ce5cfe1e Mon Sep 17 00:00:00 2001
From: hiyouga
Date: Tue, 5 Mar 2024 18:35:41 +0800
Subject: [PATCH] optimize aqlm training

Former-commit-id: d3d3dac7070eb9055bcdc91eaf53f5b3741c0bda
---
 src/llmtuner/model/loader.py  | 13 ++++++++++++-
 src/llmtuner/model/patcher.py | 28 ++++++++++++++++------------
 2 files changed, 28 insertions(+), 13 deletions(-)

diff --git a/src/llmtuner/model/loader.py b/src/llmtuner/model/loader.py
index 0760e792..45260310 100644
--- a/src/llmtuner/model/loader.py
+++ b/src/llmtuner/model/loader.py
@@ -1,3 +1,4 @@
+from contextlib import nullcontext
 from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple

 from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
@@ -86,7 +87,17 @@ def load_model(
             logger.warning("Unsloth does not support loading adapters.")

     if model is None:
-        model = AutoModelForCausalLM.from_pretrained(model_args.model_name_or_path, config=config, **init_kwargs)
+        model_init_context = nullcontext()
+        if is_trainable and getattr(config, "quantization_config", None):
+            quantization_config: Dict[str, Any] = getattr(config, "quantization_config", None)
+            if quantization_config.get("quant_method", None) == "aqlm":
+                import aqlm  # type: ignore
+
+                model_init_context = aqlm.optimize_for_training()
+                logger.info("Optimize for AQLM training.")  # https://github.com/Vahe1994/AQLM/issues/38
+
+        with model_init_context:
+            model = AutoModelForCausalLM.from_pretrained(model_args.model_name_or_path, config=config, **init_kwargs)

     patch_model(model, tokenizer, model_args, is_trainable)
     register_autoclass(config, model, tokenizer)
diff --git a/src/llmtuner/model/patcher.py b/src/llmtuner/model/patcher.py
index 5a79387c..c61f28f0 100644
--- a/src/llmtuner/model/patcher.py
+++ b/src/llmtuner/model/patcher.py
@@ -159,25 +159,25 @@ def _configure_quantization(
     r"""
     Priority: PTQ-quantized (training) > AutoGPTQ (export) > Bitsandbytes (training)
     """
-    if getattr(config, "quantization_config", None):  # gptq
+    if getattr(config, "quantization_config", None):  # ptq
         if is_deepspeed_zero3_enabled():
             raise ValueError("DeepSpeed ZeRO-3 is incompatible with quantization.")

+        init_kwargs["device_map"] = {"": get_current_device()}
         quantization_config: Dict[str, Any] = getattr(config, "quantization_config", None)
-        if quantization_config.get("quant_method", None) == "gptq" and quantization_config.get("bits", -1) == 4:
+        quant_method = quantization_config.get("quant_method", "")
+
+        if quant_method == "gptq":
             quantization_config["use_exllama"] = False  # disable exllama

-        if quantization_config.get("quant_method", None) == "aqlm":
+        if quant_method == "aqlm":
             require_version(
                 "transformers>=4.39.0.dev0", "To fix: pip install git+https://github.com/huggingface/transformers.git"
             )
             quantization_config["bits"] = 2

-        logger.info(
-            "Loading {}-bit {}-quantized model.".format(
-                quantization_config.get("bits", "?"), str(quantization_config.get("quant_method", "")).upper()
-            )
-        )
+        quant_bits = quantization_config.get("bits", "?")
+        logger.info("Loading {}-bit {}-quantized model.".format(quant_bits, quant_method.upper()))

     elif model_args.export_quantization_bit is not None:  # auto-gptq
         require_version("optimum>=1.16.0", "To fix: pip install optimum>=1.16.0")
@@ -213,6 +213,7 @@ def _configure_quantization(
             bnb_4bit_quant_type=model_args.quantization_type,
         )

+        init_kwargs["device_map"] = {"": get_current_device()}
         logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit))

@@ -285,10 +286,13 @@ def patch_config(
     init_kwargs["torch_dtype"] = model_args.compute_dtype
     if not is_deepspeed_zero3_enabled():
         init_kwargs["low_cpu_mem_usage"] = True
-        if is_trainable:
-            init_kwargs["device_map"] = {"": get_current_device()}
-        elif model_args.export_dir is None:
-            init_kwargs["device_map"] = "auto"
+        if "device_map" not in init_kwargs:
+            if is_trainable:
+                init_kwargs["device_map"] = {"": get_current_device()}
+            elif model_args.export_dir is None:
+                init_kwargs["device_map"] = "auto"
+            else:
+                init_kwargs["device_map"] = {"": "cpu"}


 def patch_model(
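
Note (illustration only, not part of the patch): the loader.py change above boils down to wrapping model initialization in aqlm.optimize_for_training() whenever the checkpoint's quantization_config reports quant_method == "aqlm", and falling back to a no-op nullcontext otherwise. The sketch below shows the same pattern outside llmtuner; the checkpoint path is a placeholder, not a name used by the patch.

# Minimal sketch of the AQLM-aware loading pattern, assuming the aqlm package
# is installed and the checkpoint path points to an AQLM-quantized model.
from contextlib import nullcontext

from transformers import AutoConfig, AutoModelForCausalLM

model_name_or_path = "path/to/aqlm-quantized-checkpoint"  # placeholder path

config = AutoConfig.from_pretrained(model_name_or_path)
quantization_config = getattr(config, "quantization_config", None) or {}

model_init_context = nullcontext()  # no-op unless the checkpoint is AQLM-quantized
if quantization_config.get("quant_method", None) == "aqlm":
    import aqlm  # requires the aqlm package

    # Prepare AQLM layers for training, see https://github.com/Vahe1994/AQLM/issues/38
    model_init_context = aqlm.optimize_for_training()

with model_init_context:
    model = AutoModelForCausalLM.from_pretrained(model_name_or_path, config=config)

Defaulting to nullcontext keeps the with statement unconditional, so non-AQLM checkpoints take exactly the same code path as before.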