mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-12-17 20:30:36 +08:00)
improve lora+ impl.
Former-commit-id: 332bad25455a70ad9204e7dd384bb086d789aa39
@@ -1,5 +1,5 @@
 from .loader import load_model, load_model_and_tokenizer, load_tokenizer
-from .utils import load_valuehead_params
+from .utils import find_all_linear_modules, load_valuehead_params


 __all__ = [
@@ -7,4 +7,5 @@ __all__ = [
     "load_model_and_tokenizer",
     "load_tokenizer",
     "load_valuehead_params",
+    "find_all_linear_modules",
 ]
@@ -5,7 +5,7 @@ from peft import LoraConfig, LoraModel, PeftModel, TaskType, get_peft_model
 from transformers.integrations import is_deepspeed_zero3_enabled

 from ..extras.logging import get_logger
-from .utils import find_all_linear_modules, find_expanded_modules
+from .utils import QuantizationMethod, find_all_linear_modules, find_expanded_modules


 if TYPE_CHECKING:
@@ -129,9 +129,9 @@ def init_adapter(
         if finetuning_args.use_llama_pro:
             target_modules = find_expanded_modules(model, target_modules, finetuning_args.num_layer_trainable)

-        if finetuning_args.use_dora:
-            if getattr(model, "quantization_method", None):
-                raise ValueError("DoRA is currently not compatible with quantized models.")
+        if finetuning_args.use_dora and getattr(model, "quantization_method", None) is not None:
+            if getattr(model, "quantization_method", None) != QuantizationMethod.BITS_AND_BYTES:
+                raise ValueError("DoRA is not compatible with PTQ-quantized models.")

         peft_kwargs = {
             "r": finetuning_args.lora_rank,
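For context, a minimal sketch (a hypothetical helper, not part of the commit) of which combinations the relaxed DoRA gate above now accepts; the string values mirror the QuantizationMethod enum introduced in the utils changes further down:

from typing import Optional

def dora_is_supported(quantization_method: Optional[str]) -> bool:
    # DoRA now only rejects PTQ checkpoints (gptq/awq/aqlm);
    # on-the-fly bitsandbytes quantization is allowed.
    return quantization_method is None or quantization_method == "bitsandbytes"

assert dora_is_supported(None)            # plain LoRA/DoRA, no quantization
assert dora_is_supported("bitsandbytes")  # QLoRA-style bnb quantization
assert not dora_is_supported("gptq")      # PTQ models still raise in init_adapter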
@@ -109,10 +109,6 @@ def load_model(

     if not is_trainable:
         model.requires_grad_(False)
-        if not getattr(model, "quantization_method", None):
-            for param in filter(lambda p: p.device.type == "cuda", model.parameters()):
-                param.data = param.data.to(model_args.compute_dtype)
-
         model.eval()
     else:
         model.train()
@@ -18,6 +18,7 @@ from ..extras.misc import get_current_device, infer_optim_dtype
 from ..extras.packages import is_flash_attn2_available
 from ..extras.patches.llama_patch import apply_llama_patch
 from ..extras.patches.mixtral_patch import patch_mixtral_replace_moe_impl
+from .utils import QuantizationMethod


 if TYPE_CHECKING:
@@ -173,10 +174,10 @@ def _configure_quantization(
         quantization_config: Dict[str, Any] = getattr(config, "quantization_config", None)
         quant_method = quantization_config.get("quant_method", "")

-        if quant_method == "gptq":
+        if quant_method == QuantizationMethod.GPTQ:
             quantization_config["use_exllama"] = False  # disable exllama

-        if quant_method == "aqlm":
+        if quant_method == QuantizationMethod.AQLM:
             require_version(
                 "transformers>=4.39.0.dev0", "To fix: pip install git+https://github.com/huggingface/transformers.git"
             )
@@ -205,7 +206,7 @@ def _configure_quantization(

     elif model_args.quantization_bit is not None:  # bnb
         if is_deepspeed_zero3_enabled():
-            raise ValueError("DeepSpeed ZeRO-3 is incompatible with quantization.")
+            require_version("bitsandbytes>=0.43.0", "To fix: pip install bitsandbytes>=0.43.0")

         if model_args.quantization_bit == 8:
             require_version("bitsandbytes>=0.37.0", "To fix: pip install bitsandbytes>=0.37.0")
@@ -1,3 +1,4 @@
+from enum import Enum, unique
 from typing import TYPE_CHECKING, Dict, List

 import torch
@@ -17,6 +18,18 @@ if TYPE_CHECKING:
 logger = get_logger(__name__)


+@unique
+class QuantizationMethod(str, Enum):
+    r"""
+    Borrowed from `transformers.utils.quantization_config.QuantizationMethod`.
+    """
+
+    BITS_AND_BYTES = "bitsandbytes"
+    GPTQ = "gptq"
+    AWQ = "awq"
+    AQLM = "aqlm"
+
+
 def find_all_linear_modules(model: "PreTrainedModel") -> List[str]:
     r"""
     Finds all available modules to apply lora.
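Because QuantizationMethod mixes in str, its members compare equal to the plain strings stored in model configs, which is what lets the patcher changes above compare quantization_config["quant_method"] against enum members directly. A minimal standalone sketch (mirroring the class added above, not additional commit code):

from enum import Enum, unique

@unique
class QuantizationMethod(str, Enum):
    BITS_AND_BYTES = "bitsandbytes"
    GPTQ = "gptq"
    AWQ = "awq"
    AQLM = "aqlm"

quant_method = {"quant_method": "gptq"}.get("quant_method", "")  # as read from a config dict
assert quant_method == QuantizationMethod.GPTQ  # str mixin: equal to the raw string "gptq"
assert quant_method != QuantizationMethod.BITS_AND_BYTES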
@@ -24,7 +37,7 @@ def find_all_linear_modules(model: "PreTrainedModel") -> List[str]:
     quantization_method = getattr(model, "quantization_method", None)
     if quantization_method is None:
         linear_cls = torch.nn.Linear
-    elif quantization_method == "bitsandbytes":
+    elif quantization_method == QuantizationMethod.BITS_AND_BYTES:
         import bitsandbytes as bnb

         linear_cls = bnb.nn.Linear4bit if getattr(model, "is_loaded_in_4bit", False) else bnb.nn.Linear8bitLt
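A hypothetical usage sketch of the newly exported helper (assuming the llmtuner.model package layout of this commit; the model id is only an example):

from transformers import AutoModelForCausalLM

from llmtuner.model import find_all_linear_modules  # re-exported in __init__.py above

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
target_modules = find_all_linear_modules(model)
# For LLaMA-style models this typically yields the attention and MLP projections,
# e.g. ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
# which can then be passed to peft.LoraConfig(target_modules=...).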