improve lora+ impl.

commit 4ef67ed4dd (parent 870b4b7285)
Author: hiyouga
Date: 2024-03-13 23:32:51 +08:00
Former-commit-id: 332bad25455a70ad9204e7dd384bb086d789aa39

12 changed files with 165 additions and 169 deletions

View File

@@ -1,5 +1,5 @@
 from .loader import load_model, load_model_and_tokenizer, load_tokenizer
-from .utils import load_valuehead_params
+from .utils import find_all_linear_modules, load_valuehead_params


 __all__ = [
@@ -7,4 +7,5 @@ __all__ = [
     "load_model_and_tokenizer",
     "load_tokenizer",
     "load_valuehead_params",
+    "find_all_linear_modules",
 ]
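
Note: `find_all_linear_modules` is now re-exported from the package root, so callers no longer need to reach into the internal `utils` module. A minimal usage sketch (the `llmtuner.model` import path and the example checkpoint are assumptions, not shown in this diff):

```python
# Hypothetical usage sketch: list every linear layer name so LoRA can target "all" modules.
from transformers import AutoModelForCausalLM

from llmtuner.model import find_all_linear_modules  # import path assumed

model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
print(find_all_linear_modules(model))  # e.g. ["q_proj", "k_proj", "v_proj", "out_proj", "fc1", "fc2"]
```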

View File

@@ -5,7 +5,7 @@ from peft import LoraConfig, LoraModel, PeftModel, TaskType, get_peft_model
 from transformers.integrations import is_deepspeed_zero3_enabled

 from ..extras.logging import get_logger
-from .utils import find_all_linear_modules, find_expanded_modules
+from .utils import QuantizationMethod, find_all_linear_modules, find_expanded_modules


 if TYPE_CHECKING:
@@ -129,9 +129,9 @@ def init_adapter(
             if finetuning_args.use_llama_pro:
                 target_modules = find_expanded_modules(model, target_modules, finetuning_args.num_layer_trainable)

-            if finetuning_args.use_dora:
-                if getattr(model, "quantization_method", None):
-                    raise ValueError("DoRA is currently not compatible with quantized models.")
+            if finetuning_args.use_dora and getattr(model, "quantization_method", None) is not None:
+                if getattr(model, "quantization_method", None) != QuantizationMethod.BITS_AND_BYTES:
+                    raise ValueError("DoRA is not compatible with PTQ-quantized models.")

             peft_kwargs = {
                 "r": finetuning_args.lora_rank,

View File

@@ -109,10 +109,6 @@ def load_model(
     if not is_trainable:
         model.requires_grad_(False)
-        if not getattr(model, "quantization_method", None):
-            for param in filter(lambda p: p.device.type == "cuda", model.parameters()):
-                param.data = param.data.to(model_args.compute_dtype)
-
         model.eval()
     else:
         model.train()

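
Note: the removed block cast CUDA-resident parameters of non-quantized, inference-only models to `model_args.compute_dtype`; after this commit the dtype chosen at load time is kept as-is. For reference, a sketch equivalent to the removed behavior:

```python
# Sketch of the behavior removed above: cast parameters already placed on CUDA
# to the requested compute dtype (e.g. torch.bfloat16) for non-trainable models.
import torch


def cast_cuda_params(model: torch.nn.Module, compute_dtype: torch.dtype) -> None:
    for param in filter(lambda p: p.device.type == "cuda", model.parameters()):
        param.data = param.data.to(compute_dtype)
```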

View File

@@ -18,6 +18,7 @@ from ..extras.misc import get_current_device, infer_optim_dtype
 from ..extras.packages import is_flash_attn2_available
 from ..extras.patches.llama_patch import apply_llama_patch
 from ..extras.patches.mixtral_patch import patch_mixtral_replace_moe_impl
+from .utils import QuantizationMethod


 if TYPE_CHECKING:
@@ -173,10 +174,10 @@ def _configure_quantization(
         quantization_config: Dict[str, Any] = getattr(config, "quantization_config", None)
         quant_method = quantization_config.get("quant_method", "")

-        if quant_method == "gptq":
+        if quant_method == QuantizationMethod.GPTQ:
             quantization_config["use_exllama"] = False  # disable exllama

-        if quant_method == "aqlm":
+        if quant_method == QuantizationMethod.AQLM:
             require_version(
                 "transformers>=4.39.0.dev0", "To fix: pip install git+https://github.com/huggingface/transformers.git"
             )
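
Note: comparing the raw `quant_method` string against `QuantizationMethod` members works because the enum mixes in `str`, so members compare equal to their plain string values. A minimal self-contained demonstration:

```python
# QuantizationMethod subclasses str, so "gptq" == QuantizationMethod.GPTQ is True
# and no explicit conversion of quantization_config values is needed.
from enum import Enum, unique


@unique
class QuantizationMethod(str, Enum):
    BITS_AND_BYTES = "bitsandbytes"
    GPTQ = "gptq"
    AWQ = "awq"
    AQLM = "aqlm"


quant_method = {"quant_method": "gptq"}.get("quant_method", "")
assert quant_method == QuantizationMethod.GPTQ
```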
@@ -205,7 +206,7 @@ def _configure_quantization(
     elif model_args.quantization_bit is not None:  # bnb
         if is_deepspeed_zero3_enabled():
-            raise ValueError("DeepSpeed ZeRO-3 is incompatible with quantization.")
+            require_version("bitsandbytes>=0.43.0", "To fix: pip install bitsandbytes>=0.43.0")

         if model_args.quantization_bit == 8:
             require_version("bitsandbytes>=0.37.0", "To fix: pip install bitsandbytes>=0.37.0")

View File

@@ -1,3 +1,4 @@
+from enum import Enum, unique
 from typing import TYPE_CHECKING, Dict, List

 import torch
@@ -17,6 +18,18 @@ if TYPE_CHECKING:
 logger = get_logger(__name__)


+@unique
+class QuantizationMethod(str, Enum):
+    r"""
+    Borrowed from `transformers.utils.quantization_config.QuantizationMethod`.
+    """
+
+    BITS_AND_BYTES = "bitsandbytes"
+    GPTQ = "gptq"
+    AWQ = "awq"
+    AQLM = "aqlm"
+
+
 def find_all_linear_modules(model: "PreTrainedModel") -> List[str]:
     r"""
     Finds all available modules to apply lora.
@@ -24,7 +37,7 @@ def find_all_linear_modules(model: "PreTrainedModel") -> List[str]:
     quantization_method = getattr(model, "quantization_method", None)
     if quantization_method is None:
         linear_cls = torch.nn.Linear
-    elif quantization_method == "bitsandbytes":
+    elif quantization_method == QuantizationMethod.BITS_AND_BYTES:
         import bitsandbytes as bnb

         linear_cls = bnb.nn.Linear4bit if getattr(model, "is_loaded_in_4bit", False) else bnb.nn.Linear8bitLt
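
Note: this hunk only shows how `linear_cls` is chosen per quantization backend. For context, a sketch of how such a scan can be completed; everything after the backend selection (the `lm_head` exclusion and returning leaf module names) is an assumption for illustration, not part of this diff:

```python
# Sketch of a full linear-module scan built on the backend selection above.
from typing import List

import torch
from transformers import PreTrainedModel

from llmtuner.model.utils import QuantizationMethod  # import path assumed


def find_all_linear_modules_sketch(model: PreTrainedModel) -> List[str]:
    quantization_method = getattr(model, "quantization_method", None)
    if quantization_method is None:
        linear_cls = torch.nn.Linear
    elif quantization_method == QuantizationMethod.BITS_AND_BYTES:
        import bitsandbytes as bnb

        linear_cls = bnb.nn.Linear4bit if getattr(model, "is_loaded_in_4bit", False) else bnb.nn.Linear8bitLt
    else:
        raise ValueError("Finding linear modules for this quantization method is not supported.")

    module_names = set()
    for name, module in model.named_modules():
        # Skip the output head; LoRA is usually applied to inner linear layers only.
        if isinstance(module, linear_cls) and "lm_head" not in name:
            module_names.add(name.split(".")[-1])  # keep the leaf name, e.g. "q_proj"

    return list(module_names)
```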