mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2026-02-26 07:45:59 +08:00
[v1] add LoRA/Freeze support and merge workflow (#10157)
This commit is contained in:
38
examples/v1/train_freeze/train_freeze_sft.yaml
Normal file
38
examples/v1/train_freeze/train_freeze_sft.yaml
Normal file
@@ -0,0 +1,38 @@
|
|||||||
|
model: Qwen/Qwen3-4B
|
||||||
|
trust_remote_code: true
|
||||||
|
model_class: llm
|
||||||
|
|
||||||
|
template: qwen3_nothink
|
||||||
|
|
||||||
|
# Freeze Configuration
|
||||||
|
peft_config:
|
||||||
|
name: freeze
|
||||||
|
freeze_trainable_layers: 2 # Train the last 2 layers
|
||||||
|
freeze_trainable_modules: all # In these layers, train specific modules
|
||||||
|
freeze_extra_modules: null # Extra modules to train (e.g. embed_tokens, lm_head)
|
||||||
|
|
||||||
|
# Kernel Config
|
||||||
|
kernel_config:
|
||||||
|
name: auto
|
||||||
|
include_kernels: auto
|
||||||
|
|
||||||
|
# FSDP Config
|
||||||
|
dist_config:
|
||||||
|
name: fsdp2
|
||||||
|
dcp_path: null
|
||||||
|
|
||||||
|
### data
|
||||||
|
train_dataset: data/v1_sft_demo.yaml
|
||||||
|
|
||||||
|
### training
|
||||||
|
output_dir: ./outputs/test_freeze
|
||||||
|
micro_batch_size: 1
|
||||||
|
global_batch_size: 4
|
||||||
|
cutoff_len: 2048
|
||||||
|
learning_rate: 2.0e-5
|
||||||
|
bf16: false
|
||||||
|
max_steps: 10
|
||||||
|
|
||||||
|
### sample
|
||||||
|
sample_backend: hf
|
||||||
|
max_new_tokens: 128
|
||||||
7
examples/v1/train_lora/export_lora.yaml
Normal file
7
examples/v1/train_lora/export_lora.yaml
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
model: Qwen/Qwen3-4B
|
||||||
|
peft_config:
|
||||||
|
name: lora
|
||||||
|
adapter_name_or_path: ./outputs/test_lora
|
||||||
|
export_dir: ./merge_lora_model
|
||||||
|
export_size: 5
|
||||||
|
infer_dtype: auto
|
||||||
39
examples/v1/train_lora/train_lora_sft.yaml
Normal file
39
examples/v1/train_lora/train_lora_sft.yaml
Normal file
@@ -0,0 +1,39 @@
|
|||||||
|
model: Qwen/Qwen3-4B
|
||||||
|
trust_remote_code: true
|
||||||
|
model_class: llm
|
||||||
|
|
||||||
|
template: qwen3_nothink
|
||||||
|
|
||||||
|
# PEFT Configuration
|
||||||
|
peft_config:
|
||||||
|
name: lora
|
||||||
|
r: 16
|
||||||
|
lora_alpha: 32
|
||||||
|
lora_dropout: 0.05
|
||||||
|
target_modules: all
|
||||||
|
|
||||||
|
# Kernel Config
|
||||||
|
kernel_config:
|
||||||
|
name: auto
|
||||||
|
include_kernels: auto
|
||||||
|
|
||||||
|
# FSDP Config
|
||||||
|
dist_config:
|
||||||
|
name: fsdp2
|
||||||
|
dcp_path: null
|
||||||
|
|
||||||
|
### data
|
||||||
|
train_dataset: data/v1_sft_demo.yaml
|
||||||
|
|
||||||
|
### training
|
||||||
|
output_dir: ./outputs/test_lora
|
||||||
|
micro_batch_size: 1
|
||||||
|
global_batch_size: 4
|
||||||
|
cutoff_len: 2048
|
||||||
|
learning_rate: 1.0e-4
|
||||||
|
bf16: true
|
||||||
|
max_steps: 10
|
||||||
|
|
||||||
|
### sample
|
||||||
|
sample_backend: hf
|
||||||
|
max_new_tokens: 128
|
||||||
@@ -204,6 +204,16 @@ class BaseTrainer:
|
|||||||
def save_model(self) -> None:
|
def save_model(self) -> None:
|
||||||
"""Save the model."""
|
"""Save the model."""
|
||||||
model_to_save = self.model.module if hasattr(self.model, "module") else self.model
|
model_to_save = self.model.module if hasattr(self.model, "module") else self.model
|
||||||
model_to_save.save_pretrained(self.args.output_dir)
|
state_dict = None
|
||||||
|
if self.args.dist_config is not None and self.args.dist_config.name == "fsdp2":
|
||||||
|
from torch.distributed.checkpoint.state_dict import StateDictOptions, get_model_state_dict
|
||||||
|
|
||||||
|
options = StateDictOptions(full_state_dict=True, cpu_offload=True)
|
||||||
|
state_dict = get_model_state_dict(self.model, options=options)
|
||||||
|
|
||||||
|
if DistributedInterface().get_rank() != 0:
|
||||||
|
return
|
||||||
|
|
||||||
|
model_to_save.save_pretrained(self.args.output_dir, state_dict=state_dict)
|
||||||
self.renderer.processor.save_pretrained(self.args.output_dir)
|
self.renderer.processor.save_pretrained(self.args.output_dir)
|
||||||
logger.info_rank0(f"Model saved to {self.args.output_dir}")
|
logger.info_rank0(f"Model saved to {self.args.output_dir}")
|
||||||
|
|||||||
@@ -125,6 +125,11 @@ def launch():
|
|||||||
|
|
||||||
run_chat()
|
run_chat()
|
||||||
|
|
||||||
|
elif command == "merge":
|
||||||
|
from llamafactory.v1.plugins.model_plugins.peft import merge_and_export_model
|
||||||
|
|
||||||
|
merge_and_export_model()
|
||||||
|
|
||||||
elif command == "env":
|
elif command == "env":
|
||||||
raise NotImplementedError("Environment information is not implemented yet.")
|
raise NotImplementedError("Environment information is not implemented yet.")
|
||||||
|
|
||||||
|
|||||||
@@ -12,14 +12,22 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from typing import Literal, TypedDict
|
import re
|
||||||
|
from typing import Literal, TypedDict, Union
|
||||||
|
|
||||||
from peft import LoraConfig, PeftModel, get_peft_model
|
import torch
|
||||||
|
from peft import LoraConfig, PeftModel, TaskType, get_peft_model
|
||||||
|
|
||||||
|
from ...config import InputArgument, get_args
|
||||||
|
from ...core.model_engine import ModelEngine
|
||||||
|
from ...utils import logging
|
||||||
from ...utils.plugin import BasePlugin
|
from ...utils.plugin import BasePlugin
|
||||||
from ...utils.types import HFModel
|
from ...utils.types import HFModel
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class LoraConfigDict(TypedDict, total=False):
|
class LoraConfigDict(TypedDict, total=False):
|
||||||
name: Literal["lora"]
|
name: Literal["lora"]
|
||||||
"""Plugin name."""
|
"""Plugin name."""
|
||||||
@@ -27,8 +35,28 @@ class LoraConfigDict(TypedDict, total=False):
|
|||||||
"""Lora rank."""
|
"""Lora rank."""
|
||||||
lora_alpha: int
|
lora_alpha: int
|
||||||
"""Lora alpha."""
|
"""Lora alpha."""
|
||||||
target_modules: list[str]
|
lora_dropout: float
|
||||||
|
"""Lora dropout."""
|
||||||
|
target_modules: Union[list[str], str]
|
||||||
"""Target modules."""
|
"""Target modules."""
|
||||||
|
use_rslora: bool
|
||||||
|
"""Use RS-LoRA."""
|
||||||
|
use_dora: bool
|
||||||
|
"""Use DoRA."""
|
||||||
|
modules_to_save: list[str]
|
||||||
|
"""Modules to save."""
|
||||||
|
adapter_name_or_path: Union[list[str], str]
|
||||||
|
"""Path to the adapter(s)."""
|
||||||
|
export_dir: str
|
||||||
|
"""Path to the export directory."""
|
||||||
|
export_size: int
|
||||||
|
"""Shard size for the export model."""
|
||||||
|
export_hub_model_id: str
|
||||||
|
"""Hub model ID for the export model."""
|
||||||
|
infer_dtype: Literal["auto", "float16", "float32", "bfloat16"]
|
||||||
|
"""Inference data type for the export model."""
|
||||||
|
export_legacy_format: bool
|
||||||
|
"""Use legacy format for the export model."""
|
||||||
|
|
||||||
|
|
||||||
class FreezeConfigDict(TypedDict, total=False):
|
class FreezeConfigDict(TypedDict, total=False):
|
||||||
@@ -36,22 +64,280 @@ class FreezeConfigDict(TypedDict, total=False):
|
|||||||
"""Plugin name."""
|
"""Plugin name."""
|
||||||
freeze_trainable_layers: int
|
freeze_trainable_layers: int
|
||||||
"""Freeze trainable layers."""
|
"""Freeze trainable layers."""
|
||||||
freeze_trainable_modules: list[str] | None
|
freeze_trainable_modules: Union[list[str], str]
|
||||||
"""Freeze trainable modules."""
|
"""Freeze trainable modules."""
|
||||||
|
freeze_extra_modules: list[str]
|
||||||
|
"""Freeze extra modules."""
|
||||||
|
cast_trainable_params_to_fp32: bool
|
||||||
|
"""Cast trainable params to fp32."""
|
||||||
|
|
||||||
|
|
||||||
class PeftPlugin(BasePlugin):
|
class PeftPlugin(BasePlugin):
|
||||||
def __call__(self, model: HFModel, config: dict, is_train: bool) -> HFModel:
|
def __call__(self, model: HFModel, config: dict, is_train: bool) -> HFModel:
|
||||||
return super().__call__(model, config)
|
return super().__call__(model, config, is_train)
|
||||||
|
|
||||||
|
|
||||||
|
def _find_all_linear_modules(model: HFModel) -> list[str]:
|
||||||
|
r"""Find all available modules to apply LoRA."""
|
||||||
|
forbidden_modules = {"lm_head", "output_layer", "output"}
|
||||||
|
module_names = set()
|
||||||
|
for name, module in model.named_modules():
|
||||||
|
if any(forbidden_module in name for forbidden_module in forbidden_modules):
|
||||||
|
continue
|
||||||
|
|
||||||
|
if "Linear" in module.__class__.__name__ and "Embedding" not in module.__class__.__name__:
|
||||||
|
module_names.add(name.split(".")[-1])
|
||||||
|
|
||||||
|
return list(module_names)
|
||||||
|
|
||||||
|
|
||||||
|
def merge_adapters(model: HFModel, adapter_name_or_path: Union[list[str], str]) -> HFModel:
|
||||||
|
if not isinstance(adapter_name_or_path, list):
|
||||||
|
adapter_name_or_path = [adapter_name_or_path]
|
||||||
|
|
||||||
|
for adapter_path in adapter_name_or_path:
|
||||||
|
model = PeftModel.from_pretrained(model, adapter_path)
|
||||||
|
model = model.merge_and_unload()
|
||||||
|
logger.info_rank0(f"Merged adapter from {adapter_path}")
|
||||||
|
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def load_adapter(model: HFModel, adapter_name_or_path: Union[list[str], str], is_train: bool) -> HFModel:
|
||||||
|
r"""Loads adapter(s) into the model.
|
||||||
|
|
||||||
|
Determine adapter usage based on mode:
|
||||||
|
- Training: Load the single adapter for continued training.
|
||||||
|
- Inference: Merge all adapters to clean up the model.
|
||||||
|
- Unmergeable: Keep the single adapter active without merging.
|
||||||
|
"""
|
||||||
|
if not isinstance(adapter_name_or_path, list):
|
||||||
|
adapter_name_or_path = [adapter_name_or_path]
|
||||||
|
|
||||||
|
# TODO
|
||||||
|
# Adapters fix for deepspeed and quant
|
||||||
|
# Adapters fix for vision
|
||||||
|
|
||||||
|
if is_train and len(adapter_name_or_path) > 1:
|
||||||
|
raise ValueError(
|
||||||
|
"When `adapter_name_or_path` is provided for training, only a single LoRA adapter is supported. "
|
||||||
|
"Training will continue on the specified adapter. "
|
||||||
|
"Please merge multiple adapters before starting a new LoRA adapter."
|
||||||
|
)
|
||||||
|
|
||||||
|
if is_train:
|
||||||
|
adapter_to_merge = []
|
||||||
|
adapter_to_resume = adapter_name_or_path[0]
|
||||||
|
else:
|
||||||
|
adapter_to_merge = adapter_name_or_path
|
||||||
|
adapter_to_resume = None
|
||||||
|
|
||||||
|
if adapter_to_merge:
|
||||||
|
model = merge_adapters(model, adapter_to_merge)
|
||||||
|
|
||||||
|
if adapter_to_resume is not None:
|
||||||
|
model = PeftModel.from_pretrained(model, adapter_to_resume, is_trainable=is_train)
|
||||||
|
if is_train:
|
||||||
|
logger.info_rank0(
|
||||||
|
f"Resuming training from existing LoRA adapter at {adapter_to_resume}. "
|
||||||
|
"LoRA hyperparameters will be loaded from the adapter itself; "
|
||||||
|
"the current LoRA configuration will be ignored. "
|
||||||
|
"Merge the adapter into the base model before training if you want to start a new adapter."
|
||||||
|
)
|
||||||
|
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
@PeftPlugin("lora").register()
|
@PeftPlugin("lora").register()
|
||||||
def get_lora_model(model: HFModel, config: LoraConfigDict, is_train: bool) -> PeftModel:
|
def get_lora_model(model: HFModel, config: LoraConfigDict, is_train: bool = False) -> HFModel:
|
||||||
peft_config = LoraConfig(**config)
|
adapter_name_or_path = config.get("adapter_name_or_path")
|
||||||
|
|
||||||
|
if adapter_name_or_path:
|
||||||
|
return load_adapter(model, adapter_name_or_path, is_train)
|
||||||
|
|
||||||
|
logger.info_rank0("Fine-tuning method: LoRA")
|
||||||
|
|
||||||
|
target_modules = config.get("target_modules", "all")
|
||||||
|
|
||||||
|
# Handle target modules
|
||||||
|
if target_modules == "all":
|
||||||
|
target_modules = _find_all_linear_modules(model)
|
||||||
|
elif isinstance(target_modules, str):
|
||||||
|
target_modules = [target_modules]
|
||||||
|
|
||||||
|
logger.info_rank0(f"LoRA target modules: {target_modules}")
|
||||||
|
|
||||||
|
peft_config = LoraConfig(
|
||||||
|
task_type=TaskType.CAUSAL_LM,
|
||||||
|
inference_mode=not is_train,
|
||||||
|
r=config.get("r", 8),
|
||||||
|
lora_alpha=config.get("lora_alpha", 16),
|
||||||
|
lora_dropout=config.get("lora_dropout", 0.05),
|
||||||
|
use_rslora=config.get("use_rslora", False),
|
||||||
|
use_dora=config.get("use_dora", False),
|
||||||
|
target_modules=target_modules,
|
||||||
|
modules_to_save=config.get("modules_to_save", None),
|
||||||
|
)
|
||||||
|
|
||||||
model = get_peft_model(model, peft_config)
|
model = get_peft_model(model, peft_config)
|
||||||
|
|
||||||
|
if is_train:
|
||||||
|
model.print_trainable_parameters()
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
@PeftPlugin("freeze").register()
|
@PeftPlugin("freeze").register()
|
||||||
def get_freeze_model(model: HFModel, config: FreezeConfigDict, is_train: bool) -> HFModel:
|
def get_freeze_model(model: HFModel, config: FreezeConfigDict, is_train: bool = False) -> HFModel:
|
||||||
raise NotImplementedError()
|
logger.info_rank0("Fine-tuning method: Freeze")
|
||||||
|
|
||||||
|
if not is_train:
|
||||||
|
return model
|
||||||
|
|
||||||
|
freeze_trainable_layers = config.get("freeze_trainable_layers", 2)
|
||||||
|
freeze_trainable_modules = config.get("freeze_trainable_modules", ["all"])
|
||||||
|
freeze_extra_modules = config.get("freeze_extra_modules", [])
|
||||||
|
cast_trainable_params_to_fp32 = config.get("cast_trainable_params_to_fp32", True)
|
||||||
|
|
||||||
|
if isinstance(freeze_trainable_modules, str):
|
||||||
|
freeze_trainable_modules = [module.strip() for module in freeze_trainable_modules.split(",")]
|
||||||
|
|
||||||
|
if isinstance(freeze_extra_modules, str):
|
||||||
|
freeze_extra_modules = [module.strip() for module in freeze_extra_modules.split(",")]
|
||||||
|
|
||||||
|
# Get number of layers
|
||||||
|
num_layers = (
|
||||||
|
getattr(model.config, "num_hidden_layers", None)
|
||||||
|
or getattr(model.config, "num_layers", None)
|
||||||
|
or getattr(model.config, "n_layer", None)
|
||||||
|
)
|
||||||
|
|
||||||
|
if not num_layers:
|
||||||
|
raise ValueError("Current model does not support freeze tuning.")
|
||||||
|
|
||||||
|
if freeze_trainable_layers > 0:
|
||||||
|
# last n layers
|
||||||
|
trainable_layer_ids = range(max(0, num_layers - freeze_trainable_layers), num_layers)
|
||||||
|
else:
|
||||||
|
# first n layers
|
||||||
|
trainable_layer_ids = range(min(-freeze_trainable_layers, num_layers))
|
||||||
|
|
||||||
|
# Identify hidden and non-hidden modules
|
||||||
|
hidden_modules = set()
|
||||||
|
non_hidden_modules = set()
|
||||||
|
for name, _ in model.named_parameters():
|
||||||
|
if ".0." in name:
|
||||||
|
hidden_modules.add(name.split(".0.")[-1].split(".")[0])
|
||||||
|
elif ".1." in name:
|
||||||
|
hidden_modules.add(name.split(".1.")[-1].split(".")[0])
|
||||||
|
|
||||||
|
if re.search(r"\.\d+\.", name) is None:
|
||||||
|
non_hidden_modules.add(name.split(".")[-2])
|
||||||
|
|
||||||
|
# Build list of trainable layer patterns
|
||||||
|
trainable_layers = []
|
||||||
|
for module_name in freeze_trainable_modules:
|
||||||
|
if module_name == "all":
|
||||||
|
for idx in trainable_layer_ids:
|
||||||
|
trainable_layers.append(f".{idx:d}.")
|
||||||
|
elif module_name in hidden_modules:
|
||||||
|
for idx in trainable_layer_ids:
|
||||||
|
trainable_layers.append(f".{idx:d}.{module_name}")
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Module {module_name} not found in hidden modules: {hidden_modules}")
|
||||||
|
|
||||||
|
# Add extra modules
|
||||||
|
if freeze_extra_modules:
|
||||||
|
for module_name in freeze_extra_modules:
|
||||||
|
if module_name in non_hidden_modules:
|
||||||
|
trainable_layers.append(module_name)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Module {module_name} not found in non-hidden modules: {non_hidden_modules}")
|
||||||
|
|
||||||
|
# TODO
|
||||||
|
# Multi-modal special handling
|
||||||
|
|
||||||
|
# Set requires_grad
|
||||||
|
forbidden_modules = {"quant_state", "quantization_weight", "qweight", "qzeros", "scales"}
|
||||||
|
for name, param in model.named_parameters():
|
||||||
|
if any(trainable_layer in name for trainable_layer in trainable_layers) and not any(
|
||||||
|
forbidden_module in name for forbidden_module in forbidden_modules
|
||||||
|
):
|
||||||
|
param.requires_grad_(True)
|
||||||
|
if cast_trainable_params_to_fp32:
|
||||||
|
param.data = param.data.to(torch.float32) # Cast to fp32 for stability
|
||||||
|
else:
|
||||||
|
param.requires_grad_(False)
|
||||||
|
|
||||||
|
logger.info_rank0(f"Set trainable layers: {trainable_layers}")
|
||||||
|
|
||||||
|
# Count trainable params for verification
|
||||||
|
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
||||||
|
all_params = sum(p.numel() for p in model.parameters())
|
||||||
|
logger.info_rank0(
|
||||||
|
f"trainable params: {trainable_params} || all params: {all_params} || trainable%: {100 * trainable_params / all_params:.4f}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def merge_and_export_model(args: InputArgument = None):
|
||||||
|
model_args, _, _, _ = get_args(args)
|
||||||
|
|
||||||
|
export_config = model_args.peft_config
|
||||||
|
if export_config is None:
|
||||||
|
raise ValueError("Please specify peft_config to merge and export model.")
|
||||||
|
|
||||||
|
export_dir = export_config.get("export_dir")
|
||||||
|
if export_dir is None:
|
||||||
|
raise ValueError("Please specify export_dir.")
|
||||||
|
|
||||||
|
export_size = export_config.get("export_size", 5)
|
||||||
|
export_hub_model_id = export_config.get("export_hub_model_id")
|
||||||
|
infer_dtype = export_config.get("infer_dtype", "auto")
|
||||||
|
export_legacy_format = export_config.get("export_legacy_format", False)
|
||||||
|
|
||||||
|
adapters = None
|
||||||
|
if export_config.get("name") == "lora":
|
||||||
|
adapters = export_config.get("adapter_name_or_path")
|
||||||
|
else:
|
||||||
|
raise ValueError("Currently merge and export model function is only supported for lora.")
|
||||||
|
|
||||||
|
if adapters is None:
|
||||||
|
raise ValueError("Please set adapter_name_or_path to merge adapters into base model.")
|
||||||
|
|
||||||
|
logger.info_rank0("Loading model for export...")
|
||||||
|
model_engine = ModelEngine(model_args, is_train=False)
|
||||||
|
model = model_engine.model
|
||||||
|
tokenizer = model_engine.processor
|
||||||
|
|
||||||
|
if infer_dtype == "auto":
|
||||||
|
if model.config.torch_dtype == torch.float32 and torch.cuda.is_bf16_supported():
|
||||||
|
model = model.to(torch.bfloat16)
|
||||||
|
logger.info_rank0("Converted model to bfloat16.")
|
||||||
|
else:
|
||||||
|
target_dtype = getattr(torch, infer_dtype)
|
||||||
|
model = model.to(target_dtype)
|
||||||
|
logger.info_rank0(f"Converted model to {infer_dtype}.")
|
||||||
|
|
||||||
|
logger.info_rank0(f"Exporting model to {export_dir}...")
|
||||||
|
model.save_pretrained(
|
||||||
|
export_dir,
|
||||||
|
max_shard_size=f"{export_size}GB",
|
||||||
|
safe_serialization=not export_legacy_format,
|
||||||
|
)
|
||||||
|
if tokenizer is not None:
|
||||||
|
try:
|
||||||
|
if hasattr(tokenizer, "padding_side"):
|
||||||
|
tokenizer.padding_side = "left"
|
||||||
|
tokenizer.save_pretrained(export_dir)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to save tokenizer: {e}")
|
||||||
|
|
||||||
|
if export_hub_model_id:
|
||||||
|
logger.info_rank0(f"Pushing to hub: {export_hub_model_id}...")
|
||||||
|
model.push_to_hub(export_hub_model_id)
|
||||||
|
if tokenizer is not None:
|
||||||
|
tokenizer.push_to_hub(export_hub_model_id)
|
||||||
|
|
||||||
|
logger.info_rank0("Model exported successfully.")
|
||||||
|
|||||||
@@ -24,6 +24,7 @@ from torch.distributed.fsdp import (
|
|||||||
fully_shard,
|
fully_shard,
|
||||||
)
|
)
|
||||||
from transformers import PreTrainedModel
|
from transformers import PreTrainedModel
|
||||||
|
from peft.tuners.lora import LoraLayer
|
||||||
|
|
||||||
from ....accelerator.helper import get_current_accelerator
|
from ....accelerator.helper import get_current_accelerator
|
||||||
from ....accelerator.interface import DistributedInterface
|
from ....accelerator.interface import DistributedInterface
|
||||||
@@ -93,6 +94,10 @@ class FSDP2Engine:
|
|||||||
reduce_dtype=reduce_dtype,
|
reduce_dtype=reduce_dtype,
|
||||||
cast_forward_inputs=True,
|
cast_forward_inputs=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def is_lora_module_wrap(self, model) -> bool:
|
||||||
|
return any(isinstance(module, LoraLayer) for module in model.modules())
|
||||||
|
|
||||||
def prepare_model(self, model: PreTrainedModel) -> PreTrainedModel:
|
def prepare_model(self, model: PreTrainedModel) -> PreTrainedModel:
|
||||||
if self.fsdp_mesh is None:
|
if self.fsdp_mesh is None:
|
||||||
@@ -110,6 +115,26 @@ class FSDP2Engine:
|
|||||||
else:
|
else:
|
||||||
logger.info(f"Applying per-layer FSDP to {layer_cls.__name__}")
|
logger.info(f"Applying per-layer FSDP to {layer_cls.__name__}")
|
||||||
transformer_layer_cls_to_wrap = {layer_cls}
|
transformer_layer_cls_to_wrap = {layer_cls}
|
||||||
|
|
||||||
|
if self.is_lora_module_wrap(model):
|
||||||
|
lora_modules = []
|
||||||
|
for module in model.modules():
|
||||||
|
|
||||||
|
if len(list(module.children())) != 0:
|
||||||
|
continue
|
||||||
|
if any(param.requires_grad for param in module.parameters(recurse=False)):
|
||||||
|
lora_modules.append(module)
|
||||||
|
|
||||||
|
for module in lora_modules:
|
||||||
|
fully_shard(
|
||||||
|
module,
|
||||||
|
mesh=self.fsdp_mesh,
|
||||||
|
reshard_after_forward=self.reshard_after_forward,
|
||||||
|
mp_policy=mp_policy,
|
||||||
|
offload_policy=CPUOffloadPolicy(pin_memory=self.pin_memory) if self.offload_params else None,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(f"Applying FSDP wrap for LoRA layer separately.")
|
||||||
|
|
||||||
for name, module in model.named_modules():
|
for name, module in model.named_modules():
|
||||||
should_wrap = False
|
should_wrap = False
|
||||||
@@ -154,7 +179,6 @@ class FSDP2Engine:
|
|||||||
)
|
)
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def materialize_and_load(self, model: PreTrainedModel, hf_model_path: str, dcp_path: str = None):
|
def materialize_and_load(self, model: PreTrainedModel, hf_model_path: str, dcp_path: str = None):
|
||||||
if self.rank == 0:
|
if self.rank == 0:
|
||||||
|
|||||||
@@ -33,7 +33,7 @@ def run_sft(args: InputArgument = None):
|
|||||||
model_args, data_args, training_args, _ = get_args(args)
|
model_args, data_args, training_args, _ = get_args(args)
|
||||||
DistributedInterface(training_args.dist_config)
|
DistributedInterface(training_args.dist_config)
|
||||||
train_dataset = DataEngine(data_args.train_dataset)
|
train_dataset = DataEngine(data_args.train_dataset)
|
||||||
model_engine = ModelEngine(model_args)
|
model_engine = ModelEngine(model_args, is_train=True)
|
||||||
trainer = SFTTrainer(
|
trainer = SFTTrainer(
|
||||||
args=training_args,
|
args=training_args,
|
||||||
model=model_engine.model,
|
model=model_engine.model,
|
||||||
|
|||||||
156
tests_v1/plugins/model_plugins/test_peft.py
Normal file
156
tests_v1/plugins/model_plugins/test_peft.py
Normal file
@@ -0,0 +1,156 @@
|
|||||||
|
# Copyright 2025 the LlamaFactory team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from peft import LoraConfig, PeftModel, get_peft_model
|
||||||
|
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||||
|
|
||||||
|
from llamafactory.v1.plugins.model_plugins import peft as peft_module
|
||||||
|
from llamafactory.v1.plugins.model_plugins.peft import merge_and_export_model
|
||||||
|
|
||||||
|
|
||||||
|
TINY_MODEL = "llamafactory/tiny-random-qwen3"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
def model_path():
|
||||||
|
return TINY_MODEL
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="function")
|
||||||
|
def model(model_path):
|
||||||
|
return AutoModelForCausalLM.from_pretrained(model_path)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="function")
|
||||||
|
def tokenizer(model_path):
|
||||||
|
return AutoTokenizer.from_pretrained(model_path)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="function")
|
||||||
|
def adapter_path(tmp_path):
|
||||||
|
# Create a dummy adapter
|
||||||
|
lora_config = LoraConfig(
|
||||||
|
r=8,
|
||||||
|
lora_alpha=16,
|
||||||
|
target_modules=["q_proj", "v_proj"],
|
||||||
|
lora_dropout=0.05,
|
||||||
|
bias="none",
|
||||||
|
task_type="CAUSAL_LM",
|
||||||
|
)
|
||||||
|
|
||||||
|
base_model = AutoModelForCausalLM.from_pretrained(TINY_MODEL)
|
||||||
|
peft_model = get_peft_model(base_model, lora_config)
|
||||||
|
save_path = tmp_path / "test_adapter"
|
||||||
|
peft_model.save_pretrained(save_path)
|
||||||
|
return str(save_path)
|
||||||
|
|
||||||
|
|
||||||
|
def test_find_all_linear_modules(model):
|
||||||
|
"""Verify linear modules are discoverable and include q_proj / v_proj for tiny-random-qwen3."""
|
||||||
|
modules = peft_module._find_all_linear_modules(model)
|
||||||
|
expected_subset = {"q_proj", "v_proj"}
|
||||||
|
assert expected_subset.issubset(set(modules))
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_lora_model(model):
|
||||||
|
"""Verify a PeftModel is returned and LoRA config takes effect."""
|
||||||
|
config = {"name": "lora", "r": 8, "target_modules": "all", "lora_alpha": 16}
|
||||||
|
model = peft_module.get_lora_model(model, config, is_train=True)
|
||||||
|
assert isinstance(model, PeftModel)
|
||||||
|
assert model.peft_config["default"].r == 8
|
||||||
|
assert "q_proj" in model.peft_config["default"].target_modules
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_freeze_model_layers(model):
|
||||||
|
"""Verify layer-wise freezing: only the last layer stays trainable."""
|
||||||
|
# Freeze all but last layer
|
||||||
|
config = {"name": "freeze", "freeze_trainable_layers": 1, "freeze_trainable_modules": "all"}
|
||||||
|
|
||||||
|
# Ensure we start with something known
|
||||||
|
model = peft_module.get_freeze_model(model, config, is_train=True)
|
||||||
|
|
||||||
|
num_layers = model.config.num_hidden_layers
|
||||||
|
assert num_layers > 0
|
||||||
|
|
||||||
|
for name, param in model.named_parameters():
|
||||||
|
if f"layers.{num_layers - 1}" in name:
|
||||||
|
assert param.requires_grad, f"{name} should be trainable"
|
||||||
|
elif "layers.0" in name and num_layers > 1:
|
||||||
|
assert not param.requires_grad, f"{name} should be frozen"
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_freeze_model_modules(model):
|
||||||
|
"""Verify module-wise freezing: only last-layer self_attn is trainable."""
|
||||||
|
# Freeze specific modules (e.g. only self_attn)
|
||||||
|
config = {"name": "freeze", "freeze_trainable_layers": 1, "freeze_trainable_modules": "self_attn"}
|
||||||
|
model = peft_module.get_freeze_model(model, config, is_train=True)
|
||||||
|
|
||||||
|
num_layers = model.config.num_hidden_layers
|
||||||
|
|
||||||
|
for name, param in model.named_parameters():
|
||||||
|
if f"layers.{num_layers - 1}" in name and "self_attn" in name:
|
||||||
|
assert param.requires_grad, f"{name} should be trainable"
|
||||||
|
else:
|
||||||
|
assert not param.requires_grad, f"{name} should be frozen"
|
||||||
|
|
||||||
|
|
||||||
|
def test_load_adapter_single_for_inference(model, adapter_path):
|
||||||
|
"""Verify single adapter is merged+unloaded in inference mode."""
|
||||||
|
# Test loading single adapter for inference (merge and unload)
|
||||||
|
model_result = peft_module.load_adapter(model, adapter_path, is_train=False)
|
||||||
|
assert not isinstance(model_result, PeftModel)
|
||||||
|
|
||||||
|
|
||||||
|
def test_load_adapter_resume_train(model, adapter_path):
|
||||||
|
"""Verify training mode returns a trainable PeftModel."""
|
||||||
|
# Test loading for training
|
||||||
|
model_result = peft_module.load_adapter(model, adapter_path, is_train=True)
|
||||||
|
assert isinstance(model_result, PeftModel)
|
||||||
|
|
||||||
|
|
||||||
|
def test_load_adapter_train_multiple_disallowed(model, adapter_path):
|
||||||
|
"""Verify multiple adapters are rejected in training mode."""
|
||||||
|
with pytest.raises(ValueError, match="only a single LoRA adapter"):
|
||||||
|
peft_module.load_adapter(model, [adapter_path, adapter_path], is_train=True)
|
||||||
|
|
||||||
|
|
||||||
|
def test_load_adapter_infer_multiple_merges(model, adapter_path):
|
||||||
|
"""Verify multiple adapters are merged in inference mode."""
|
||||||
|
# Test merging multiple adapters
|
||||||
|
model_result = peft_module.load_adapter(model, [adapter_path, adapter_path], is_train=False)
|
||||||
|
assert not isinstance(model_result, PeftModel)
|
||||||
|
|
||||||
|
|
||||||
|
def test_merge_and_export_model(tmp_path, adapter_path):
|
||||||
|
"""Verify merge_and_export_model produces export artifacts."""
|
||||||
|
export_dir = tmp_path / "export"
|
||||||
|
|
||||||
|
args_dict = {
|
||||||
|
"model": TINY_MODEL,
|
||||||
|
"peft_config": {
|
||||||
|
"name": "lora",
|
||||||
|
"adapter_name_or_path": adapter_path,
|
||||||
|
"export_dir": str(export_dir),
|
||||||
|
"export_size": 1,
|
||||||
|
"infer_dtype": "float16",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
merge_and_export_model(args_dict)
|
||||||
|
|
||||||
|
assert export_dir.exists()
|
||||||
|
assert (export_dir / "config.json").exists()
|
||||||
|
assert (export_dir / "model.safetensors").exists()
|
||||||
|
assert (export_dir / "tokenizer_config.json").exists()
|
||||||
Reference in New Issue
Block a user