mirror of https://github.com/hiyouga/LLaMA-Factory.git
[v1] add LoRA/Freeze support and merge workflow (#10157)
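
This commit registers two PEFT plugins, "lora" and "freeze", plus a LoRA merge-and-export workflow. As a rough sketch of the kind of peft_config dict the plugins consume, using only field names declared in the TypedDicts below (how the dict is wired in from the CLI/YAML side is an assumption, not shown in this diff):

    # Hypothetical peft_config payloads; every key comes from LoraConfigDict / FreezeConfigDict below.
    lora_config = {
        "name": "lora",
        "r": 8,
        "lora_alpha": 16,
        "lora_dropout": 0.05,
        "target_modules": "all",  # "all" expands to every eligible Linear module; an explicit list also works
        "use_rslora": False,
        "use_dora": False,
    }

    freeze_config = {
        "name": "freeze",
        "freeze_trainable_layers": 2,  # positive: train the last 2 layers; negative: train the first |n| layers
        "freeze_trainable_modules": "all",
        "cast_trainable_params_to_fp32": True,
    }
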
@@ -12,14 +12,22 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import Literal, TypedDict
+import re
+from typing import Literal, TypedDict, Union

-from peft import LoraConfig, PeftModel, get_peft_model
+import torch
+from peft import LoraConfig, PeftModel, TaskType, get_peft_model

+from ...config import InputArgument, get_args
+from ...core.model_engine import ModelEngine
 from ...utils import logging
 from ...utils.plugin import BasePlugin
 from ...utils.types import HFModel


 logger = logging.get_logger(__name__)


 class LoraConfigDict(TypedDict, total=False):
     name: Literal["lora"]
     """Plugin name."""
@@ -27,8 +35,28 @@ class LoraConfigDict(TypedDict, total=False):
     """Lora rank."""
     lora_alpha: int
     """Lora alpha."""
-    target_modules: list[str]
+    lora_dropout: float
+    """Lora dropout."""
+    target_modules: Union[list[str], str]
     """Target modules."""
+    use_rslora: bool
+    """Use RS-LoRA."""
+    use_dora: bool
+    """Use DoRA."""
+    modules_to_save: list[str]
+    """Modules to save."""
+    adapter_name_or_path: Union[list[str], str]
+    """Path to the adapter(s)."""
+    export_dir: str
+    """Path to the export directory."""
+    export_size: int
+    """Shard size for the export model."""
+    export_hub_model_id: str
+    """Hub model ID for the export model."""
+    infer_dtype: Literal["auto", "float16", "float32", "bfloat16"]
+    """Inference data type for the export model."""
+    export_legacy_format: bool
+    """Use legacy format for the export model."""


 class FreezeConfigDict(TypedDict, total=False):
@@ -36,22 +64,280 @@ class FreezeConfigDict(TypedDict, total=False):
     """Plugin name."""
     freeze_trainable_layers: int
     """Freeze trainable layers."""
-    freeze_trainable_modules: list[str] | None
+    freeze_trainable_modules: Union[list[str], str]
     """Freeze trainable modules."""
     freeze_extra_modules: list[str]
     """Freeze extra modules."""
     cast_trainable_params_to_fp32: bool
     """Cast trainable params to fp32."""


 class PeftPlugin(BasePlugin):
     def __call__(self, model: HFModel, config: dict, is_train: bool) -> HFModel:
-        return super().__call__(model, config)
+        return super().__call__(model, config, is_train)


+def _find_all_linear_modules(model: HFModel) -> list[str]:
+    r"""Find all available modules to apply LoRA."""
+    forbidden_modules = {"lm_head", "output_layer", "output"}
+    module_names = set()
+    for name, module in model.named_modules():
+        if any(forbidden_module in name for forbidden_module in forbidden_modules):
+            continue
+
+        if "Linear" in module.__class__.__name__ and "Embedding" not in module.__class__.__name__:
+            module_names.add(name.split(".")[-1])
+
+    return list(module_names)
+
+
+def merge_adapters(model: HFModel, adapter_name_or_path: Union[list[str], str]) -> HFModel:
+    if not isinstance(adapter_name_or_path, list):
+        adapter_name_or_path = [adapter_name_or_path]
+
+    for adapter_path in adapter_name_or_path:
+        model = PeftModel.from_pretrained(model, adapter_path)
+        model = model.merge_and_unload()
+        logger.info_rank0(f"Merged adapter from {adapter_path}")
+
+    return model
+
+
+def load_adapter(model: HFModel, adapter_name_or_path: Union[list[str], str], is_train: bool) -> HFModel:
+    r"""Loads adapter(s) into the model.
+
+    Determine adapter usage based on mode:
+    - Training: Load the single adapter for continued training.
+    - Inference: Merge all adapters to clean up the model.
+    - Unmergeable: Keep the single adapter active without merging.
+    """
+    if not isinstance(adapter_name_or_path, list):
+        adapter_name_or_path = [adapter_name_or_path]
+
+    # TODO
+    # Adapters fix for deepspeed and quant
+    # Adapters fix for vision
+
+    if is_train and len(adapter_name_or_path) > 1:
+        raise ValueError(
+            "When `adapter_name_or_path` is provided for training, only a single LoRA adapter is supported; "
+            "training resumes from that adapter. "
+            "Merge multiple adapters into the base model before starting a new LoRA adapter."
+        )
+
+    if is_train:
+        adapter_to_merge = []
+        adapter_to_resume = adapter_name_or_path[0]
+    else:
+        adapter_to_merge = adapter_name_or_path
+        adapter_to_resume = None
+
+    if adapter_to_merge:
+        model = merge_adapters(model, adapter_to_merge)
+
+    if adapter_to_resume is not None:
+        model = PeftModel.from_pretrained(model, adapter_to_resume, is_trainable=is_train)
+        if is_train:
+            logger.info_rank0(
+                f"Resuming training from existing LoRA adapter at {adapter_to_resume}. "
+                "LoRA hyperparameters will be loaded from the adapter itself; "
+                "the current LoRA configuration will be ignored. "
+                "Merge the adapter into the base model before training if you want to start a new adapter."
+            )
+
+    return model
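
A minimal usage sketch for load_adapter as defined above (adapter paths are placeholders, not taken from this commit):

    # Inference: every adapter in the list is merged into the base weights.
    model = load_adapter(model, ["saves/adapter_a", "saves/adapter_b"], is_train=False)

    # Training: exactly one adapter may be given; it is loaded as trainable and training resumes from it.
    model = load_adapter(model, "saves/adapter_a", is_train=True)
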


 @PeftPlugin("lora").register()
-def get_lora_model(model: HFModel, config: LoraConfigDict, is_train: bool) -> PeftModel:
-    peft_config = LoraConfig(**config)
+def get_lora_model(model: HFModel, config: LoraConfigDict, is_train: bool = False) -> HFModel:
+    adapter_name_or_path = config.get("adapter_name_or_path")
+
+    if adapter_name_or_path:
+        return load_adapter(model, adapter_name_or_path, is_train)
+
+    logger.info_rank0("Fine-tuning method: LoRA")
+
+    target_modules = config.get("target_modules", "all")
+
+    # Handle target modules
+    if target_modules == "all":
+        target_modules = _find_all_linear_modules(model)
+    elif isinstance(target_modules, str):
+        target_modules = [target_modules]
+
+    logger.info_rank0(f"LoRA target modules: {target_modules}")
+
+    peft_config = LoraConfig(
+        task_type=TaskType.CAUSAL_LM,
+        inference_mode=not is_train,
+        r=config.get("r", 8),
+        lora_alpha=config.get("lora_alpha", 16),
+        lora_dropout=config.get("lora_dropout", 0.05),
+        use_rslora=config.get("use_rslora", False),
+        use_dora=config.get("use_dora", False),
+        target_modules=target_modules,
+        modules_to_save=config.get("modules_to_save", None),
+    )
+
+    model = get_peft_model(model, peft_config)
+
+    if is_train:
+        model.print_trainable_parameters()
+
+    return model


 @PeftPlugin("freeze").register()
-def get_freeze_model(model: HFModel, config: FreezeConfigDict, is_train: bool) -> HFModel:
-    raise NotImplementedError()
+def get_freeze_model(model: HFModel, config: FreezeConfigDict, is_train: bool = False) -> HFModel:
+    logger.info_rank0("Fine-tuning method: Freeze")
+
+    if not is_train:
+        return model
+
+    freeze_trainable_layers = config.get("freeze_trainable_layers", 2)
+    freeze_trainable_modules = config.get("freeze_trainable_modules", ["all"])
+    freeze_extra_modules = config.get("freeze_extra_modules", [])
+    cast_trainable_params_to_fp32 = config.get("cast_trainable_params_to_fp32", True)
+
+    if isinstance(freeze_trainable_modules, str):
+        freeze_trainable_modules = [module.strip() for module in freeze_trainable_modules.split(",")]
+
+    if isinstance(freeze_extra_modules, str):
+        freeze_extra_modules = [module.strip() for module in freeze_extra_modules.split(",")]
+
+    # Get number of layers
+    num_layers = (
+        getattr(model.config, "num_hidden_layers", None)
+        or getattr(model.config, "num_layers", None)
+        or getattr(model.config, "n_layer", None)
+    )
+
+    if not num_layers:
+        raise ValueError("Current model does not support freeze tuning.")
+
+    if freeze_trainable_layers > 0:
+        # last n layers
+        trainable_layer_ids = range(max(0, num_layers - freeze_trainable_layers), num_layers)
+    else:
+        # first n layers
+        trainable_layer_ids = range(min(-freeze_trainable_layers, num_layers))
+
+    # Identify hidden and non-hidden modules
+    hidden_modules = set()
+    non_hidden_modules = set()
+    for name, _ in model.named_parameters():
+        if ".0." in name:
+            hidden_modules.add(name.split(".0.")[-1].split(".")[0])
+        elif ".1." in name:
+            hidden_modules.add(name.split(".1.")[-1].split(".")[0])
+
+        if re.search(r"\.\d+\.", name) is None:
+            non_hidden_modules.add(name.split(".")[-2])
+
+    # Build list of trainable layer patterns
+    trainable_layers = []
+    for module_name in freeze_trainable_modules:
+        if module_name == "all":
+            for idx in trainable_layer_ids:
+                trainable_layers.append(f".{idx:d}.")
+        elif module_name in hidden_modules:
+            for idx in trainable_layer_ids:
+                trainable_layers.append(f".{idx:d}.{module_name}")
+        else:
+            raise ValueError(f"Module {module_name} not found in hidden modules: {hidden_modules}")
+
+    # Add extra modules
+    if freeze_extra_modules:
+        for module_name in freeze_extra_modules:
+            if module_name in non_hidden_modules:
+                trainable_layers.append(module_name)
+            else:
+                raise ValueError(f"Module {module_name} not found in non-hidden modules: {non_hidden_modules}")
+
+    # TODO
+    # Multi-modal special handling
+
+    # Set requires_grad
+    forbidden_modules = {"quant_state", "quantization_weight", "qweight", "qzeros", "scales"}
+    for name, param in model.named_parameters():
+        if any(trainable_layer in name for trainable_layer in trainable_layers) and not any(
+            forbidden_module in name for forbidden_module in forbidden_modules
+        ):
+            param.requires_grad_(True)
+            if cast_trainable_params_to_fp32:
+                param.data = param.data.to(torch.float32)  # Cast to fp32 for stability
+        else:
+            param.requires_grad_(False)
+
+    logger.info_rank0(f"Set trainable layers: {trainable_layers}")
+
+    # Count trainable params for verification
+    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+    all_params = sum(p.numel() for p in model.parameters())
+    logger.info_rank0(
+        f"trainable params: {trainable_params} || all params: {all_params} || "
+        f"trainable%: {100 * trainable_params / all_params:.4f}"
+    )
+
+    return model
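
To make the layer selection above concrete, a small worked example assuming a hypothetical 32-layer model:

    num_layers = 32

    # freeze_trainable_layers = 2 -> train the last two layers
    list(range(max(0, num_layers - 2), num_layers))  # [30, 31] -> patterns ".30." and ".31."

    # freeze_trainable_layers = -2 -> train the first two layers
    list(range(min(2, num_layers)))  # [0, 1] -> patterns ".0." and ".1."
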
+
+
+def merge_and_export_model(args: InputArgument = None):
+    model_args, _, _, _ = get_args(args)
+
+    export_config = model_args.peft_config
+    if export_config is None:
+        raise ValueError("Please specify peft_config to merge and export model.")
+
+    export_dir = export_config.get("export_dir")
+    if export_dir is None:
+        raise ValueError("Please specify export_dir.")
+
+    export_size = export_config.get("export_size", 5)
+    export_hub_model_id = export_config.get("export_hub_model_id")
+    infer_dtype = export_config.get("infer_dtype", "auto")
+    export_legacy_format = export_config.get("export_legacy_format", False)
+
+    adapters = None
+    if export_config.get("name") == "lora":
+        adapters = export_config.get("adapter_name_or_path")
+    else:
+        raise ValueError("Currently the merge and export function only supports LoRA.")
+
+    if adapters is None:
+        raise ValueError("Please set adapter_name_or_path to merge adapters into the base model.")
+
+    logger.info_rank0("Loading model for export...")
+    model_engine = ModelEngine(model_args, is_train=False)
+    model = model_engine.model
+    tokenizer = model_engine.processor
+
+    if infer_dtype == "auto":
+        if model.config.torch_dtype == torch.float32 and torch.cuda.is_bf16_supported():
+            model = model.to(torch.bfloat16)
+            logger.info_rank0("Converted model to bfloat16.")
+    else:
+        target_dtype = getattr(torch, infer_dtype)
+        model = model.to(target_dtype)
+        logger.info_rank0(f"Converted model to {infer_dtype}.")
+
+    logger.info_rank0(f"Exporting model to {export_dir}...")
+    model.save_pretrained(
+        export_dir,
+        max_shard_size=f"{export_size}GB",
+        safe_serialization=not export_legacy_format,
+    )
+    if tokenizer is not None:
+        try:
+            if hasattr(tokenizer, "padding_side"):
+                tokenizer.padding_side = "left"
+            tokenizer.save_pretrained(export_dir)
+        except Exception as e:
+            logger.warning(f"Failed to save tokenizer: {e}")
+
+    if export_hub_model_id:
+        logger.info_rank0(f"Pushing to hub: {export_hub_model_id}...")
+        model.push_to_hub(export_hub_model_id)
+        if tokenizer is not None:
+            tokenizer.push_to_hub(export_hub_model_id)
+
+    logger.info_rank0("Model exported successfully.")